# Tutorials
> This bundle contains all pages in the Tutorials section.
> Source: https://www.union.ai/docs/v2/union/tutorials/

=== PAGE: https://www.union.ai/docs/v2/union/tutorials ===

# Tutorials

> **📝 Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.

This section contains tutorials that showcase relevant use cases and provide step-by-step instructions on how to implement various features using Flyte and Union.

### **Automatic prompt engineering**

Easily run prompt optimization with real-time observability, traceability, and automatic recovery.

### **GPU-accelerated climate modeling**

Run ensemble atmospheric simulations on H200 GPUs with multi-source data ingestion and real-time extreme event detection.

### **Run LLM-generated code**

Securely execute and iterate on LLM-generated code using a code agent with error reflection and retry logic.

### **Deep research**

Build an agentic workflow for deep research with multi-step reasoning and evaluation.

### **Distributed LLM pretraining**

Pretrain large language models at scale with PyTorch Lightning, FSDP, and H200 GPUs, featuring streaming data and real-time metrics.

### **Fine-tuning a vision-language model with a frozen backbone**

Adapt Qwen2.5-VL to occluded image classification by training a 10K-parameter adapter with multi-node DeepSpeed, automatic recovery, and live training dashboards.

### **Hyperparameter optimization**

Run large-scale HPO experiments with zero manual tracking, deterministic results, and automatic recovery.

### **Multi-agent trading simulation**

Simulate how agents within a trading firm might interact, strategize, and make trades collaboratively.

### **Text-to-SQL**

Learn how to turn natural language questions into SQL queries with Flyte and LlamaIndex, and explore prompt optimization in practice.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/distributed-pretraining ===

# Distributed LLM pretraining

When training large models, infrastructure should not be the hardest part. The real work is in the model architecture, the data, and the hyperparameters. In practice, though, teams often spend weeks just trying to get distributed training to run reliably.

And when it breaks, it usually breaks in familiar ways: out-of-memory crashes, corrupted checkpoints, data loaders that silently fail, or runs that hang with no obvious explanation.

Most distributed training tutorials focus on PyTorch primitives. This one focuses on getting something that actually ships. We go into the technical details, such as how FSDP shards parameters, why gradient clipping behaves differently at scale, and how streaming datasets reduce memory pressure, but always with the goal of building a system that works in production.

Real training jobs need more than a training loop. They need checkpointing, fault tolerance, data streaming, visibility into what’s happening, and the ability to recover from failures. In this tutorial, we build all of that using Flyte, without having to stand up or manage any additional infrastructure.

> [!NOTE]
> Full code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/pretraining/train.py).

## Overview

We're going to pretrain a GPT-2 style language model from scratch. This involves training on raw text data starting from randomly initialized weights, rather than fine-tuning or adapting a pretrained model. This is the same process used to train the original GPT-2, LLaMA, and most other foundation models.

The model learns by predicting the next token. Given "The cat sat on the", it learns to predict "mat". Do this billions of times across terabytes of text, and the model develops surprisingly sophisticated language understanding. That's pretraining.

The challenge is scale. A 30B-parameter model, together with its gradients and optimizer state, doesn't fit in a single GPU's memory. The training dataset, [SlimPajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) in our case, contains 627 billion tokens. Training runs last for days or even weeks. To make this work, you need:

- **Distributed training**: Split the model across multiple GPUs using [FSDP (Fully Sharded Data Parallel)](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- **Data streaming**: Pull training data on-demand instead of downloading terabytes upfront
- **Checkpointing**: Save progress regularly so a failure doesn’t wipe out days of compute
- **Observability**: See what's happening inside a multi-day training run
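
To see why a 30B model has to be split across GPUs, here is a back-of-the-envelope estimate of the memory needed just for model state. It assumes mixed-precision training with AdamW and the common rule of thumb of about 16 bytes per parameter (bf16 weights and gradients, an fp32 master copy of the weights, and two fp32 Adam moments), split evenly across GPUs as FSDP's full sharding does. Activations, temporary buffers, and framework overhead are ignored, and the names below are illustrative rather than part of the tutorial's code.

```python
# Rough per-GPU memory for model state (weights, gradients, optimizer state).
# Assumption: mixed-precision AdamW at ~16 bytes/parameter; activations ignored.

BYTES_PER_PARAM = {
    "bf16_weights": 2,
    "bf16_grads": 2,
    "fp32_master_weights": 4,
    "adam_moments": 8,  # exp_avg + exp_avg_sq, both fp32
}


def model_state_gb(num_params: float, num_gpus: int = 1) -> float:
    """Model-state memory per GPU in GB (1e9 bytes), assuming an even shard split."""
    total_bytes = num_params * sum(BYTES_PER_PARAM.values())
    return total_bytes / num_gpus / 1e9


if __name__ == "__main__":
    H200_GB = 141
    for gpus in (1, 8):
        per_gpu = model_state_gb(30e9, gpus)
        print(f"30B params on {gpus} GPU(s): ~{per_gpu:.0f} GB/GPU "
              f"(fits in {H200_GB} GB: {per_gpu < H200_GB})")
```

Under these assumptions, a 30B-parameter model needs roughly 480 GB of model state, far more than one 141 GB H200, while sharding across eight GPUs brings it to about 60 GB per GPU and leaves headroom for activations.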

We’ll build a Flyte pipeline that takes care of all of this, using three tasks with clearly defined responsibilities:

1. **Data preparation**: Tokenizes your dataset and converts it to MDS (MosaicML Data Shard) format for streaming. This Flyte task is cached, so it only needs to be run once and can be reused across runs.
2. **Distributed training**: Runs FSDP across 8 H200 GPUs. Flyte's `Elastic` plugin handles the distributed setup. Checkpoints upload to S3 automatically via Flyte's `File` abstraction.
3. **Real-time reporting**: Streams loss curves and training metrics to Flyte Reports, a live dashboard integrated into the Flyte UI.

Why three separate tasks? Flyte makes this separation efficient:

- **Caching**: The data preparation step runs once. On subsequent runs, Flyte skips it entirely.
- **Resource isolation**: Expensive H200 GPUs are allocated only to the training task, while the driver runs on inexpensive CPU instances.
- **Fault boundaries**: If training fails, the data preparation step does not re-run. Training can resume directly from the most recent checkpoint.

## Implementation

Let's walk through the code. We'll start with the infrastructure setup, build the model, then wire everything together into a pipeline.

### Setting up the environment

Every distributed training job needs a consistent environment across all nodes. Flyte handles this with container images:

```
import logging
import math
import os
from pathlib import Path
from typing import Optional

import flyte
import flyte.report
import lightning as L
import numpy as np
import torch
import torch.nn as nn
from flyte.io import Dir, File
from flyteplugins.pytorch.task import Elastic
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The imports tell the story: `flyte` for orchestration, `flyte.report` for live dashboards, `lightning` for training loop management, and `Elastic` from Flyte's PyTorch plugin. This last one is key: it configures PyTorch's distributed launch so you don't have to write any distributed setup code yourself.

```
NUM_NODES = 1
DEVICES_PER_NODE = 8
VOCAB_SIZE = (
    50257  # GPT-2 BPE tokenizer vocabulary size (constant across all model sizes)
)
N_POSITIONS = 2048  # Maximum sequence length (constant across all model sizes)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

These constants define the distributed topology and the model's input dimensions. We're using 1 node with 8 GPUs, but you can scale out by increasing `NUM_NODES`. The vocabulary size (50,257 tokens) matches GPT-2's [Byte Pair Encoding (BPE) tokenizer](https://huggingface.co/learn/llm-course/en/chapter6/5); the maximum sequence length of 2,048 tokens is twice the original GPT-2 context length of 1,024.

```
image = flyte.Image.from_debian_base(
    name="distributed_training_h200"
).with_pip_packages(
    "transformers==4.57.3",
    "datasets==4.4.1",
    "tokenizers==0.22.1",
    "huggingface-hub==0.34.0",
    "mosaicml-streaming>=0.7.0",
    "pyarrow==22.0.0",
    "flyteplugins-pytorch>=2.0.0b33",
    "torch==2.9.1",
    "lightning==2.5.6",
    "tensorboard==2.20.0",
    "sentencepiece==0.2.1",
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

Flyte builds this container automatically when the pipeline is run. All dependencies required for distributed training, including PyTorch, Lightning, the streaming library, and NCCL for GPU communication, are baked in. There's no Dockerfile to maintain and no "works on my machine" debugging.

### Declaring resource requirements

Different parts of the pipeline need different resources. Data tokenization needs CPU and memory. Training needs GPUs. The driver just coordinates. Flyte's `TaskEnvironment` lets you declare exactly what each task needs:

```
data_loading_env = flyte.TaskEnvironment(
    name="data_loading_h200",
    image=image,
    resources=flyte.Resources(cpu=5, memory="28Gi", disk="100Gi"),
    env_vars={
        "HF_DATASETS_CACHE": "/tmp/hf_cache",  # Cache directory for datasets
        "TOKENIZERS_PARALLELISM": "true",  # Enable parallel tokenization
    },
    cache="auto",
)

distributed_llm_training_env = flyte.TaskEnvironment(
    name="distributed_llm_training_h200",
    image=image,
    resources=flyte.Resources(
        cpu=64,
        memory="512Gi",
        gpu=f"H200:{DEVICES_PER_NODE}",
        disk="1Ti",
        shm="16Gi",  # Explicit shared memory for NCCL communication
    ),
    plugin_config=Elastic(nnodes=NUM_NODES, nproc_per_node=DEVICES_PER_NODE),
    env_vars={
        "TORCH_DISTRIBUTED_DEBUG": "INFO",
        "NCCL_DEBUG": "WARN",
    },
    cache="auto",
)

driver_env = flyte.TaskEnvironment(
    name="llm_training_driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    cache="auto",
    depends_on=[data_loading_env, distributed_llm_training_env],
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

Let's break down the training environment, since this is where most of the complexity lives:

- **`gpu=f"H200:{DEVICES_PER_NODE}"`**: Flyte provisions 8 H200 GPUs per node. Each has 141 GB of memory; with FSDP sharding the model across all eight, that is enough to train 30B+ parameter models.
- **`shm="16Gi"`**: This allocates explicit shared memory. NCCL (NVIDIA's communication library) uses shared memory for inter-GPU communication on the same node. Without this, you'll see cryptic errors like "NCCL error: unhandled system error", which can be difficult to debug.
- **`Elastic(nnodes=NUM_NODES, nproc_per_node=DEVICES_PER_NODE)`**: This is Flyte's integration with PyTorch's elastic launch. It handles process spawning (one process per GPU), rank assignment (each process knows its ID), and environment setup (master address, world size). This replaces the boilerplate typically written in shell scripts.
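
PyTorch's elastic launcher (`torchrun`) communicates the topology to each worker through standard environment variables such as `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT`. Assuming the `Elastic` plugin launches workers the same way, a minimal sketch of inspecting them inside a task might look like this. `DistInfo` and `read_dist_info` are hypothetical helpers, and the fallback values are chosen here for single-process runs.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class DistInfo:
    rank: int         # global rank across all nodes
    local_rank: int   # GPU index on this node
    world_size: int   # total processes: nnodes * nproc_per_node
    master_addr: str  # address of the rank-0 host used for rendezvous
    master_port: int

    @property
    def is_rank_zero(self) -> bool:
        return self.rank == 0


def read_dist_info(env: Mapping[str, str] = os.environ) -> DistInfo:
    """Read torchrun-style variables, defaulting to a single-process run."""
    return DistInfo(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
        master_addr=env.get("MASTER_ADDR", "localhost"),
        master_port=int(env.get("MASTER_PORT", "29500")),
    )


if __name__ == "__main__":
    info = read_dist_info()
    print(f"rank {info.rank}/{info.world_size} (local GPU {info.local_rank})")
```

In practice you rarely need this, since Lightning reads the same variables itself; it is mostly useful for logging or for gating work to a single process, as the checkpoint callback does with `trainer.global_rank`.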

The `driver_env` is intentionally lightweight, using 2 CPUs and 4 GB of memory. Its role is limited to orchestrating tasks and passing data between them, so allocating GPUs here would be unnecessary.

### Model configurations

Training a 1.5B model uses different hyperparameters than training a 65B model. Rather than hardcoding values, we define presets:

```
MODEL_CONFIGS = {
    "1.5B": {
        "n_embd": 2048,
        "n_layer": 24,
        "n_head": 16,
        "batch_size": 8,
        "learning_rate": 6e-4,
        "checkpoint_every_n_steps": 10,
        "report_every_n_steps": 5,
        "val_check_interval": 100,
    },  # Good for testing and debugging
    "30B": {
        "n_embd": 6656,
        "n_layer": 48,
        "n_head": 52,
        "batch_size": 1,
        "learning_rate": 1.6e-4,
        "checkpoint_every_n_steps": 7500,
        "report_every_n_steps": 200,
        "val_check_interval": 1000,
    },
    "65B": {
        "n_embd": 8192,
        "n_layer": 80,
        "n_head": 64,
        "batch_size": 1,
        "learning_rate": 1.5e-4,
        "checkpoint_every_n_steps": 10000,
        "report_every_n_steps": 250,
        "val_check_interval": 2000,
    },
}

def get_model_config(model_size: str) -> dict:
    if model_size not in MODEL_CONFIGS:
        available = ", ".join(MODEL_CONFIGS.keys())
        raise ValueError(f"Unknown model size: {model_size}. Available: {available}")

    return MODEL_CONFIGS[model_size]
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

A few things to notice:

- **Batch size decreases with model size**: For a fixed GPU memory budget, larger models consume more memory for parameters, optimizer state, and activations, leaving less room for per-GPU batch size. For example, a 1.5B parameter model may fit a batch size of 8 per GPU, while a 65B model may only fit a batch size of 1. This is typically compensated for using gradient accumulation to maintain a larger effective batch size.
- **Learning rate decreases with model size**: Larger models are more sensitive to optimization instability and typically require lower learning rates. The values here follow empirical best practices used in large-scale language model training, informed by work such as the [Chinchilla study](https://arxiv.org/pdf/2203.15556) on compute-optimal scaling.
- **Checkpoint interval grows with model size**: Checkpointing a 65B model is expensive because the checkpoint is huge. We save less often (every 10,000 steps instead of every 10) while still bounding how much progress a failure can cost.

The 1.5B config is good for testing your setup before committing to a serious training run.
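
Gradient accumulation (the training task's `grad_accumulation_steps` parameter) trades optimizer steps for memory. Here is a quick sketch of the arithmetic, assuming every sample is a full `N_POSITIONS`-token sequence and the single-node, 8-GPU topology above. The helper names and the target of 524,288 (2^19) tokens per step are illustrative assumptions, not values from the tutorial.

```python
import math

N_POSITIONS = 2048
NUM_NODES = 1
DEVICES_PER_NODE = 8
WORLD_SIZE = NUM_NODES * DEVICES_PER_NODE


def tokens_per_step(batch_size: int, grad_accum: int = 1,
                    world_size: int = WORLD_SIZE, seq_len: int = N_POSITIONS) -> int:
    """Tokens consumed per optimizer step across all data-parallel ranks."""
    return batch_size * grad_accum * world_size * seq_len


def accum_steps_for(target_tokens: int, batch_size: int,
                    world_size: int = WORLD_SIZE, seq_len: int = N_POSITIONS) -> int:
    """Smallest accumulation factor that reaches at least target_tokens per step."""
    per_micro_step = batch_size * world_size * seq_len
    return max(1, math.ceil(target_tokens / per_micro_step))


if __name__ == "__main__":
    target = 2**19  # hypothetical target: ~0.5M tokens per optimizer step
    for preset, batch in (("1.5B", 8), ("30B", 1)):
        steps = accum_steps_for(target, batch)
        print(f"{preset}: {tokens_per_step(batch)} tokens/micro-step, "
              f"grad_accumulation_steps={steps}")
```

Because FSDP is still data parallel, every rank consumes its own batch, so adding nodes raises the tokens per step and lowers the accumulation needed for the same target.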

### Building the GPT model

Now for the model itself. We're building a GPT-2 style decoder-only transformer from scratch.

First, the configuration class:

```
class GPTConfig:
    """Configuration for GPT model."""

    def __init__(
        self,
        vocab_size: int = VOCAB_SIZE,
        n_positions: int = N_POSITIONS,
        n_embd: int = 2048,
        n_layer: int = 24,
        n_head: int = 16,
        n_inner: Optional[int] = None,
        activation_function: str = "gelu_new",
        dropout: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner if n_inner is not None else 4 * n_embd
        self.activation_function = activation_function
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The key architectural parameters:

- **`n_embd`**: The hidden (embedding) dimension. Larger values increase model capacity but also increase memory and compute requirements.
- **`n_layer`**: The number of transformer blocks. Model depth strongly influences expressiveness and performance.
- **`n_head`**: The number of attention heads. Each head can attend to different patterns or relationships in the input.
- **`n_inner`**: The hidden dimension of the feed-forward network (MLP), typically set to 4x the embedding dimension.
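
These parameters largely determine model size. A common approximation for a GPT-style decoder is 12 × n_layer × n_embd² parameters in the transformer blocks (4·n_embd² for the attention projections plus 8·n_embd² for an MLP with `n_inner = 4 * n_embd`), plus the token and positional embedding tables. The sketch below applies it to the presets, assuming tied input/output embeddings and ignoring biases and LayerNorm weights. By this estimate the preset names are approximate labels: the "1.5B" preset comes out closer to 1.3B and the "30B" preset closer to 26B.

```python
VOCAB_SIZE = 50257
N_POSITIONS = 2048

# n_embd / n_layer copied from MODEL_CONFIGS above.
PRESETS = {
    "1.5B": {"n_embd": 2048, "n_layer": 24},
    "30B": {"n_embd": 6656, "n_layer": 48},
    "65B": {"n_embd": 8192, "n_layer": 80},
}


def approx_params(n_embd: int, n_layer: int,
                  vocab_size: int = VOCAB_SIZE, n_positions: int = N_POSITIONS) -> int:
    """Rough parameter count for a GPT-2 style decoder with tied embeddings.

    Ignores biases and LayerNorm weights, which are a tiny fraction of the total.
    """
    n_inner = 4 * n_embd
    attention = 4 * n_embd * n_embd                   # Q, K, V, and output projections
    mlp = 2 * n_embd * n_inner                        # up- and down-projection
    embeddings = (vocab_size + n_positions) * n_embd  # token + positional tables
    return n_layer * (attention + mlp) + embeddings


if __name__ == "__main__":
    for name, shape in PRESETS.items():
        print(f"{name:>5} preset: ~{approx_params(**shape) / 1e9:.1f}B parameters")
```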

Next, we define a single transformer block:

```
class GPTBlock(nn.Module):
    """Transformer block with causal self-attention."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = nn.MultiheadAttention(
            config.n_embd,
            config.n_head,
            dropout=config.dropout,
            batch_first=True,
        )
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Get activation function from config
        ACT_FNS = {
            "gelu": nn.GELU(),
            "gelu_new": nn.GELU(approximate="tanh"),  # GPT-2 uses approximate GELU
            "relu": nn.ReLU(),
            "silu": nn.SiLU(),
            "swish": nn.SiLU(),  # SiLU = Swish
        }
        act_fn = ACT_FNS.get(config.activation_function, nn.GELU())

        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, config.n_inner),
            act_fn,
            nn.Linear(config.n_inner, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x, causal_mask, key_padding_mask=None):
        x_normed = self.ln_1(x)

        # Self-attention with causal and padding masks
        attn_output, _ = self.attn(
            x_normed,  # query
            x_normed,  # key
            x_normed,  # value
            attn_mask=causal_mask,  # Causal mask: (seq_len, seq_len)
            key_padding_mask=key_padding_mask,  # Padding mask: (batch, seq_len)
            need_weights=False,
        )
        x = x + attn_output

        # MLP with residual
        x = x + self.mlp(self.ln_2(x))
        return x
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

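Each block has two sub-layers, causal self-attention and a feed-forward MLP, each preceded by LayerNorm (a pre-norm layout) and wrapped in a residual connection. The causal mask ensures each position can attend only to itself and earlier tokens in the sequence, so the model can't "cheat" by looking at the token it is supposed to predict. This is what makes it *autoregressive*.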
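
`nn.MultiheadAttention` accepts a boolean `attn_mask` in which `True` marks positions a query is *not* allowed to attend to. The complete code builds its mask inside `GPTModel`; as a dependency-free illustration of the pattern (not the tutorial's implementation), a causal mask blocks every key position `j > i` for query position `i`:

```python
def causal_mask(seq_len: int) -> list[list[bool]]:
    """Boolean causal mask: mask[i][j] is True when query i must NOT see key j."""
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]


if __name__ == "__main__":
    # "x" = blocked, "." = visible; row i is the query position.
    for row in causal_mask(4):
        print("".join("x" if blocked else "." for blocked in row))
```

In PyTorch, the same boolean mask can be built with `torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)`.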

The full `GPTModel` class (see the complete code) stacks these blocks and adds token and positional embeddings. One important detail is that the input token embedding matrix is shared with the output projection layer, a technique known as [weight tying](https://mbrenndoerfer.com/writing/weight-tying-shared-embeddings-transformers). This saves `vocab_size × n_embd` parameters, roughly 103 million for the 1.5B preset (50,257 × 2,048) and 412 million for the 65B preset (50,257 × 8,192), and often leads to better generalization and more stable training.

### The Lightning training module

PyTorch Lightning handles the training loop boilerplate. We wrap our model in a `LightningModule` that defines how to train it:

```
class GPTPreTrainingModule(L.LightningModule):
    """PyTorch Lightning module for GPT pre-training."""

    def __init__(
        self,
        vocab_size: int = 50257,
        n_positions: int = 2048,
        n_embd: int = 2048,
        n_layer: int = 24,
        n_head: int = 16,
        learning_rate: float = 6e-4,
        weight_decay: float = 0.1,
        warmup_steps: int = 2000,
        max_steps: int = 100000,
    ):
        super().__init__()
        self.save_hyperparameters()

        config = GPTConfig(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
        )
        self.model = GPTModel(config)

    def forward(self, input_ids, attention_mask=None):
        return self.model(input_ids, attention_mask)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The `save_hyperparameters()` call is important because it stores all constructor arguments in the checkpoint. This allows the model to be reloaded later without having to manually reconstruct the original configuration.

The training and validation steps implement standard causal language modeling, where the model is trained to predict the next token given all previous tokens in the sequence.

```
    def training_step(self, batch, _batch_idx):
        # Convert int32 to int64 (long) - MDS stores as int32 but PyTorch expects long
        input_ids = batch["input_ids"].long()
        labels = batch["labels"].long()

        # Get attention mask if present (optional, for padded sequences)
        # attention_mask: 1 = real token, 0 = padding
        # Note: Current data pipeline creates fixed-length sequences without padding,
        # so attention_mask is not present. If using padded sequences, ensure:
        #   - Padded positions in labels are set to -100 (ignored by cross_entropy)
        #   - attention_mask marks real tokens (1) vs padding (0)
        attention_mask = batch.get("attention_mask", None)

        # Forward pass (causal mask is created internally in GPTModel)
        logits = self(input_ids, attention_mask=attention_mask)

        # Shift logits and labels for causal language modeling
        # Before shift: labels[i] = input_ids[i]
        # After shift: predict input_ids[i+1] from input_ids[:i+1]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Calculate loss
        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        # Log loss
        self.log(
            "train/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

        # Calculate and log perplexity only on epoch (exp is costly, less frequent is fine)
        perplexity = torch.exp(torch.clamp(loss, max=20.0))
        self.log(
            "train/perplexity",
            perplexity,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch, _batch_idx):
        # Convert int32 to int64 (long) - MDS stores as int32 but PyTorch expects long
        input_ids = batch["input_ids"].long()
        labels = batch["labels"].long()

        # Get attention mask if present (optional, for padded sequences)
        attention_mask = batch.get("attention_mask", None)

        # Forward pass (causal mask is created internally in GPTModel)
        logits = self(input_ids, attention_mask=attention_mask)

        # Shift logits and labels
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Calculate loss
        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        # Log loss
        self.log("val/loss", loss, prog_bar=True, sync_dist=True)

        # Calculate and log perplexity (exp is costly, but validation is infrequent so OK)
        perplexity = torch.exp(torch.clamp(loss, max=20.0))
        self.log("val/perplexity", perplexity, prog_bar=True, sync_dist=True)

        return loss
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The model performs a forward pass with a causal (autoregressive) mask created internally, ensuring each token can attend only to itself and earlier positions. To align predictions with targets, the logits and labels are shifted so that the output at position `i` is used to predict token `i + 1`.

Loss is computed using cross-entropy over the shifted logits and labels, ignoring any label set to `-100`. Loss and perplexity are logged for both training and validation, with `sync_dist=True` synchronizing the metrics across distributed workers.
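
To make the shift concrete, here is a small pure-Python version of the same computation for a single sequence: drop the last position's logits, drop the first label, then average the cross-entropy over positions whose label is not `-100`. It is meant to mirror `nn.functional.cross_entropy(..., ignore_index=-100)` on the shifted tensors, as an illustration only; the helper names are hypothetical.

```python
import math

IGNORE_INDEX = -100  # same convention as cross_entropy(ignore_index=-100)


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over one vector of vocabulary scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def causal_lm_loss(logits: list[list[float]], labels: list[int]) -> float:
    """Mean next-token cross-entropy for one sequence.

    logits[i] holds the vocabulary scores at position i and labels[i] the token there.
    Dropping the last logit row and the first label mirrors logits[..., :-1, :] and
    labels[..., 1:], so row i is scored against labels[i + 1].
    """
    shift_logits = logits[:-1]
    shift_labels = labels[1:]
    losses = [
        -log_softmax(row)[target]
        for row, target in zip(shift_logits, shift_labels)
        if target != IGNORE_INDEX
    ]
    if not losses:
        return float("nan")  # every target ignored; nothing to average
    return sum(losses) / len(losses)


def perplexity(loss: float, max_loss: float = 20.0) -> float:
    """exp(loss) with the same clamp the tutorial applies before torch.exp."""
    return math.exp(min(loss, max_loss))
```

With uniform logits over a vocabulary of size V, the loss is log V and the perplexity is V. That makes a handy sanity check: with GPT-2's 50,257-token vocabulary, a freshly initialized model typically starts with a loss in the vicinity of ln(50,257) ≈ 10.8.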

The optimizer setup is where a lot of training stability comes from:

```
    def configure_optimizers(self):
        # Separate parameters into weight decay and no weight decay groups
        decay_params = []
        no_decay_params = []

        for param in self.model.parameters():
            if param.requires_grad:
                # 1D parameters (biases, LayerNorm) don't get weight decay
                # 2D+ parameters (weight matrices) get weight decay
                if param.ndim == 1:
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)

        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.hparams.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        # AdamW optimizer
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate,
            betas=(0.9, 0.95),
            eps=1e-8,
        )

        # Learning rate scheduler: warmup + cosine decay
        # Warmup: linear increase from 0 to 1.0 over warmup_steps
        # Decay: cosine decay from 1.0 to 0.0 over remaining steps
        def lr_lambda(current_step):
            if current_step < self.hparams.warmup_steps:
                # Linear warmup
                return float(current_step) / float(max(1, self.hparams.warmup_steps))

            # Cosine decay after warmup
            progress = (current_step - self.hparams.warmup_steps) / max(
                1, self.hparams.max_steps - self.hparams.warmup_steps
            )
            # Cosine annealing from 1.0 to 0.0 (returns float, not tensor)
            return 0.5 * (1.0 + math.cos(progress * math.pi))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

Two important choices here:

1. **Separate weight decay groups**: We apply weight decay only to weight matrices (parameters with two or more dimensions), not to biases or LayerNorm parameters. This follows the original BERT implementation and is now standard practice, since regularizing biases and normalization parameters generally doesn't help and can hurt.
2. **Cosine learning rate schedule with warmup**: The learning rate starts at zero, ramps up linearly over `warmup_steps` (which helps stabilize early training when gradients are noisy), then decays to zero along a cosine curve by `max_steps`. This schedule is a widely used default for transformer pretraining.
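
The `lr_lambda` above returns a multiplier on the base learning rate. Pulled out as a standalone function (same formula, with `warmup_steps` and `max_steps` passed explicitly instead of read from `self.hparams`), its shape is easy to check at a few milestones:

```python
import math


def lr_multiplier(step: int, warmup_steps: int, max_steps: int) -> float:
    """Linear warmup from 0.0 to 1.0, then cosine decay to 0.0 at max_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    base_lr = 6e-4  # the "1.5B" preset's learning rate
    for step in (0, 1000, 2000, 51000, 100000):
        lr = base_lr * lr_multiplier(step, warmup_steps=2000, max_steps=100000)
        print(f"step {step:>6}: lr = {lr:.2e}")
```

Past `max_steps` the cosine term starts rising again, so the schedule's `max_steps` should match the trainer's step limit.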

### Checkpointing for fault tolerance

Training a 30B-parameter model for 15,000 steps can take days. Hardware failures and spot instance preemptions are inevitable, which makes checkpointing essential.

```
class S3CheckpointCallback(L.Callback):
    """
    Periodically upload checkpoints to S3 for durability and resumption.

    This ensures checkpoints are safely stored in remote storage even if
    the training job is interrupted or the instance fails.
    """

    def __init__(self, checkpoint_dir: Path, upload_every_n_steps: int):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.upload_every_n_steps = upload_every_n_steps
        self.last_uploaded_step = -1

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Upload checkpoint to S3 every N steps."""
        if trainer.global_rank != 0:
            return  # Only upload from rank 0

        current_step = trainer.global_step

        # Upload every N steps (aligns with ModelCheckpoint's every_n_train_steps)
        if (
            current_step % self.upload_every_n_steps == 0
            and current_step > self.last_uploaded_step
            and current_step > 0
        ):
            try:
                # Find the most recent checkpoint file
                checkpoint_files = list(self.checkpoint_dir.glob("*.ckpt"))
                if not checkpoint_files:
                    print("No checkpoint files found to upload")
                    return

                # Get the latest checkpoint (by modification time)
                latest_checkpoint = max(
                    checkpoint_files, key=lambda p: p.stat().st_mtime
                )

                # Upload the checkpoint file directly to S3 using File.from_local_sync
                checkpoint_file = File.from_local_sync(str(latest_checkpoint))
                print(f"Checkpoint uploaded to S3 at: {checkpoint_file.path}")

                self.last_uploaded_step = current_step
            except Exception as e:
                print(f"Warning: Failed to upload checkpoint to S3: {e}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

This callback runs every `N` training steps and uploads the most recent local checkpoint to durable storage. The key line is `File.from_local_sync()`, a Flyte abstraction for uploading files. There are no blob store credentials to manage and no bucket paths to hardcode: Flyte automatically uses the storage backend configured for your cluster.

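The callback only runs on rank 0. With plain data parallelism every GPU holds an identical copy of the model; with FSDP each GPU holds only a shard, but when Lightning saves a full (non-sharded) checkpoint, the shards are gathered into a single file. Either way, one upload captures the whole training state, and having all 8 ranks upload it would be wasteful and could cause race conditions.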

When you restart a failed run, pass the checkpoint via `resume_checkpoint` so training resumes exactly where it left off, including the same step count, optimizer state, and learning rate schedule position.
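
How the checkpoint gets from `resume_checkpoint` into Lightning isn't shown in this excerpt. One plausible approach, sketched under assumptions: once the directory is available locally, pick the checkpoint with the highest step and pass its path to `trainer.fit(..., ckpt_path=...)`, which restores the model, optimizer, scheduler, and step counters. The helper below assumes Lightning's default `ModelCheckpoint` file naming (`epoch=<e>-step=<s>.ckpt`); `find_latest_checkpoint` and `local_resume_dir` are hypothetical names, not part of the tutorial's code.

```python
import re
from pathlib import Path
from typing import Optional

# Assumes Lightning's default ModelCheckpoint names, e.g. "epoch=0-step=7500.ckpt".
STEP_PATTERN = re.compile(r"step=(\d+)")


def find_latest_checkpoint(checkpoint_dir: Path) -> Optional[Path]:
    """Return the .ckpt file with the highest step number, or None if there is none."""
    best_step, best_path = -1, None
    for path in checkpoint_dir.rglob("*.ckpt"):
        match = STEP_PATTERN.search(path.name)
        if match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


# Inside the training task (sketch, local_resume_dir is hypothetical):
#   ckpt = find_latest_checkpoint(local_resume_dir) if local_resume_dir else None
#   trainer.fit(model, train_dataloaders=..., ckpt_path=str(ckpt) if ckpt else None)
```

Selecting by the step in the filename, rather than by modification time as the upload callback does, avoids surprises when downloaded files all receive fresh timestamps.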

### Real-time metrics with Flyte Reports

Multi-day training runs need observability. Is the loss decreasing? Did training diverge? Is the learning rate schedule behaving correctly? Flyte Reports let you build live dashboards directly in the UI:

```
class FlyteReportingCallback(L.Callback):
    """Custom Lightning callback to report training metrics to Flyte Report."""

    def __init__(self, report_every_n_steps: int = 100):
        super().__init__()
        self.report_every_n_steps = report_every_n_steps
        self.metrics_history = {
            "step": [],
            "train_loss": [],
            "learning_rate": [],
            "val_loss": [],
            "val_perplexity": [],
        }
        self.initialized_report = False
        self.last_logged_step = -1

    def on_train_start(self, trainer, pl_module):
        """Initialize the live dashboard on training start."""
        if trainer.global_rank == 0 and not self.initialized_report:
            self._initialize_report()
            self.initialized_report = True
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The `_initialize_report` method (see complete code) creates an HTML/JavaScript dashboard with interactive charts. The callback then calls `flyte.report.log()` every `N` steps to push new metrics. The charts update in real time, so you can watch your loss curve descend while training runs.
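
The bookkeeping behind such a callback is mostly throttling: record a data point only every `N` steps, never twice for the same step, and render the history for the report. The sketch below isolates that logic in plain Python. `MetricsBuffer` and its HTML table are hypothetical simplifications of the tutorial's chart dashboard; in the real callback the rendered content is pushed with `flyte.report.log()`, which this sketch does not call.

```python
class MetricsBuffer:
    """Keep a metrics history, recording at most once every `every_n` steps."""

    def __init__(self, every_n: int = 100):
        self.every_n = every_n
        self.last_logged_step = -1
        self.history: dict[str, list] = {"step": [], "train_loss": [], "learning_rate": []}

    def maybe_record(self, step: int, train_loss: float, learning_rate: float) -> bool:
        """Append a data point if `step` is on the interval and not recorded yet."""
        if step % self.every_n != 0 or step <= self.last_logged_step:
            return False
        self.history["step"].append(step)
        self.history["train_loss"].append(train_loss)
        self.history["learning_rate"].append(learning_rate)
        self.last_logged_step = step
        return True

    def to_html(self) -> str:
        """Render the history as a minimal HTML table (the tutorial uses JS charts)."""
        header = "<tr><th>step</th><th>train loss</th><th>learning rate</th></tr>"
        rows = "".join(
            f"<tr><td>{s}</td><td>{loss:.4f}</td><td>{lr:.2e}</td></tr>"
            for s, loss, lr in zip(
                self.history["step"],
                self.history["train_loss"],
                self.history["learning_rate"],
            )
        )
        return f"<table>{header}{rows}</table>"
```

The duplicate-step guard matters because, with gradient accumulation, Lightning calls `on_train_batch_end` once per micro-batch while `trainer.global_step` advances only once per optimizer step.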

There is no need to deploy Grafana, configure Prometheus, or keep a TensorBoard server running. Using `flyte.report.log()` is sufficient to get live training metrics directly in the Flyte UI.

![Metrics viz](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/distributed-llm-pretraining/metrics.png)

### Streaming data at scale

Training datasets are massive. SlimPajama contains 627 billion tokens and spans hundreds of gigabytes even when compressed. Downloading the entire dataset to each training node before starting would take hours and waste storage.

Instead, we convert the data to MDS (MosaicML Data Shard) format and stream it during training:

```
@data_loading_env.task
async def load_and_prepare_streaming_dataset(
    dataset_name: str,
    dataset_config: Optional[str],
    max_length: int,
    train_split: str,
    val_split: Optional[str],
    max_train_samples: Optional[int],
    max_val_samples: Optional[int],
    shard_size_mb: int,
) -> Dir:
    """Tokenize dataset and convert to MDS format for streaming."""
    from datasets import load_dataset
    from streaming import MDSWriter
    from transformers import GPT2TokenizerFast

    output_dir = Path("/tmp/streaming_dataset")
    output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # MDS schema: what each sample contains
    columns = {
        "input_ids": "ndarray:int32",
        "labels": "ndarray:int32",
    }

    # ... tokenize, pack into fixed-length sequences, and write MDS shards (see the complete code) ...
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

This task does three things:

1. **Tokenizes the text** using GPT-2's BPE tokenizer
2. **Concatenates documents** into fixed-length sequences (no padding waste)
3. **Writes shards** to storage in a format optimized for streaming
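
Step 2 is what lets the training code skip attention masks. Here is a minimal sketch of the idea, assuming documents are already tokenized and separated by GPT-2's end-of-text token (id 50256); trailing tokens that don't fill a full block are dropped. `pack_sequences` is an illustrative helper, not the tutorial's function. Each block would then be written as one `input_ids`/`labels` sample; because `training_step` applies the shift itself, `labels` can simply equal `input_ids`.

```python
from typing import Iterable, Iterator

EOS_TOKEN_ID = 50256  # GPT-2's <|endoftext|> token


def pack_sequences(docs: Iterable[list[int]], max_length: int,
                   eos_id: int = EOS_TOKEN_ID) -> Iterator[list[int]]:
    """Concatenate tokenized documents (EOS-separated) into fixed-length blocks."""
    buffer: list[int] = []
    for tokens in docs:
        buffer.extend(tokens)
        buffer.append(eos_id)  # mark the document boundary
        while len(buffer) >= max_length:
            yield buffer[:max_length]
            buffer = buffer[max_length:]
    # Leftover tokens shorter than max_length are dropped, so no padding is needed.


if __name__ == "__main__":
    docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
    for block in pack_sequences(docs, max_length=4, eos_id=0):
        print({"input_ids": block, "labels": block})
```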

The task returns a Flyte `Dir` object, which is a reference to the output location. It's not the data itself, just a pointer. When the training task receives this `Dir`, it streams shards on-demand rather than downloading everything upfront.

Because `data_loading_env` sets `cache="auto"`, Flyte caches this task's output. Run the pipeline twice with the same dataset config, and Flyte skips tokenization entirely on the second run. Change the dataset or sequence length, and it re-runs.

### Distributed training with FSDP

Now we get to the core: actually training the model across multiple GPUs. FSDP (Fully Sharded Data Parallel) is what makes this possible for large models.

```
@distributed_llm_training_env.task(report=True)
def train_distributed_llm(
    prepared_dataset: Dir,
    resume_checkpoint: Optional[Dir],
    vocab_size: int,
    n_positions: int,
    n_embd: int,
    n_layer: int,
    n_head: int,
    batch_size: int,
    num_workers: int,
    max_steps: int,
    learning_rate: float,
    weight_decay: float,
    warmup_steps: int,
    use_fsdp: bool,
    checkpoint_upload_steps: int,
    checkpoint_every_n_steps: int,
    report_every_n_steps: int,
    val_check_interval: int,
    grad_accumulation_steps: int = 1,
) -> Optional[Dir]:
    # ... setup code ...
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

Notice `report=True` on the task decorator. It enables Flyte Reports for this specific task.

The training task receives the prepared dataset as a `Dir` and streams data directly from storage:

```
    # StreamingDataset streams shards from the remote Flyte storage on-demand
    # It automatically detects torch.distributed context
    # and shards data across GPUs - each rank gets different data automatically
    train_dataset = StreamingDataset(
        remote=f"{remote_path}/train",  # Remote MDS shard location
        local=str(local_cache / "train"),  # Local cache for downloaded shards
        shuffle=True,  # Shuffle samples
        shuffle_algo="naive",  # Shuffling algorithm
        batch_size=batch_size,  # Used for shuffle buffer sizing
    )

    # Create validation StreamingDataset if it exists
    val_dataset = None
    try:
        val_dataset = StreamingDataset(
            remote=f"{remote_path}/validation",
            local=str(local_cache / "validation"),
            shuffle=False,  # No shuffling for validation
            batch_size=batch_size,
        )
        print(
            f"Validation dataset initialized with streaming from: {remote_path}/validation"
        )
    except Exception as e:
        print(f"No validation dataset found: {e}")

    # Create data loaders
    # StreamingDataset handles distributed sampling internally by detecting
    # torch.distributed.get_rank() and torch.distributed.get_world_size()
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
        drop_last=True,  # Drop incomplete batches for distributed training
        collate_fn=mds_collate_fn,  # Handle read-only arrays
    )

    # Create validation loader if validation dataset exists
    val_loader = None
    if val_dataset is not None:
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=True,
            drop_last=False,
            collate_fn=mds_collate_fn,
        )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

`prepared_dataset.path` provides the remote storage path for the dataset; it is what `remote_path` refers to in the snippet above. MosaicML's `StreamingDataset` automatically shards data across GPUs so that each rank sees different samples, without requiring a manual distributed sampler. Storage credentials are already available in the task environment, because Flyte sets them up.
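
To make "each rank sees different samples" concrete, here is a deliberately simplified sketch of rank-based sharding. It is not MosaicML's algorithm, which partitions by shard and node and also handles shuffling and mid-epoch resumption; it only shows the core idea that every rank derives a disjoint slice from its rank and the world size. `rank_indices` is a hypothetical name.

```python
def rank_indices(num_samples: int, rank: int, world_size: int) -> list[int]:
    """Simplified illustration of per-rank data sharding (not MosaicML's algorithm).

    Every rank takes a strided slice starting at its own rank, so no two ranks
    read the same sample. The remainder that would leave ranks with unequal
    counts is dropped so all ranks run the same number of steps.
    """
    usable = num_samples - num_samples % world_size
    return list(range(rank, usable, world_size))


world_size = 4
for rank in range(world_size):
    print(rank, rank_indices(10, rank, world_size))
```

Keeping every rank at the same sample count matters in distributed training: a rank that runs out of data early would stop joining collective operations and stall the others.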

FSDP is where the memory savings come from. With Distributed Data Parallel (DDP), every GPU holds a full copy of the model. FSDP instead shards the parameters, gradients, and optimizer states across all GPUs, so with 8 GPUs each one holds only 1/8 of them. When an FSDP unit needs to run, FSDP all-gathers that unit's full parameters, runs the computation, then frees the gathered copy.
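
A back-of-the-envelope calculation shows why sharding matters for the 30B preset used later in this tutorial. The sketch assumes 16 bytes of persistent state per parameter (FP32 weights, FP32 gradients, and two FP32 AdamW moments) and ignores activations, buffers, the temporarily gathered FSDP unit, and allocator overhead, so real usage is higher. `per_gpu_state_gib` is a hypothetical helper.

```python
def per_gpu_state_gib(
    num_params: float, world_size: int, sharded: bool, bytes_per_param: int = 16
) -> float:
    """Approximate per-GPU memory (GiB) for weights, gradients and AdamW state.

    Assumption: 16 bytes per parameter = FP32 weights (4) + FP32 gradients (4)
    + two FP32 AdamW moments (8). Activations and other overheads are ignored.
    """
    total_bytes = num_params * bytes_per_param
    if sharded:
        # FSDP FULL_SHARD: each rank keeps only 1/world_size of this state.
        total_bytes /= world_size
    return total_bytes / 2**30


num_params = 30e9  # the "30B" preset, taken at face value
for name, sharded in [("DDP", False), ("FSDP", True)]:
    print(f"{name}: {per_gpu_state_gib(num_params, 8, sharded):.0f} GiB per GPU")
```

Under these assumptions, DDP would need about 447 GiB per GPU, far beyond a single H200's 141 GB, while FSDP across 8 GPUs needs about 56 GiB per GPU before activations.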

```
    # Configure distributed strategy
    if use_fsdp:
        from lightning.pytorch.strategies import FSDPStrategy
        from torch.distributed.fsdp.wrap import ModuleWrapPolicy

        strategy = FSDPStrategy(
            auto_wrap_policy=ModuleWrapPolicy([GPTBlock]),
            activation_checkpointing_policy=None,
            cpu_offload=False,  # H200 has 141GB - no CPU offload needed
            state_dict_type="full",
            sharding_strategy="FULL_SHARD",
            process_group_backend="nccl",
        )
    else:
        from lightning.pytorch.strategies import DDPStrategy

        strategy = DDPStrategy(process_group_backend="nccl")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

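Wrapping with `ModuleWrapPolicy([GPTBlock])` makes each transformer block its own FSDP unit. The unit size is a trade-off: more, smaller units mean more all-gather calls (communication overhead), while fewer, larger units raise peak memory, because a whole unit's parameters must be gathered at once. One unit per transformer block is a common middle ground.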

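One subtle detail: gradient clipping. With FSDP, each rank holds only a shard of the gradients, so a global gradient norm cannot be computed locally; the ranks would first have to combine their partial norms. PyTorch offers `FullyShardedDataParallel.clip_grad_norm_` for this, but Lightning's FSDP strategy does not support `gradient_clip_algorithm="norm"`. The trainer therefore uses value-based clipping with FSDP, which clamps each individual gradient element to a fixed range. Each rank can do this independently, with no coordination. The trade-off is that value clipping can change the direction of the gradient, whereas norm clipping only rescales it.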
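
The difference between the two algorithms is easy to see on a toy gradient vector. This is a plain-Python sketch of the semantics only; in PyTorch the corresponding utilities are `torch.nn.utils.clip_grad_value_` and `torch.nn.utils.clip_grad_norm_`, and the `1e-6` term mirrors the small epsilon that `clip_grad_norm_` adds to the norm.

```python
import math


def clip_by_value(grads: list[float], clip_val: float) -> list[float]:
    # Each element is clamped independently: no global statistic is needed,
    # so every rank can do this on its own gradient shard.
    return [max(-clip_val, min(clip_val, g)) for g in grads]


def clip_by_norm(grads: list[float], max_norm: float) -> list[float]:
    # Needs the norm of the *whole* gradient vector; under sharding this
    # means combining partial sums of squares from every rank first.
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads]


grads = [3.0, -0.5, 0.2]
print(clip_by_value(grads, 1.0))  # [1.0, -0.5, 0.2]
print(clip_by_norm(grads, 1.0))   # same direction, norm scaled to ~1.0
```

Value clipping turned `[3.0, -0.5, 0.2]` into `[1.0, -0.5, 0.2]`, which points in a different direction, while norm clipping shrank all three components by the same factor.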

```
    # Initialize trainer
    trainer = L.Trainer(
        strategy=strategy,
        accelerator="gpu",
        devices=DEVICES_PER_NODE,
        num_nodes=NUM_NODES,
        # Training configuration
        max_steps=max_steps,
        precision="bf16-mixed",  # BFloat16 for better numerical stability
        # Optimization
        gradient_clip_val=1.0,
        gradient_clip_algorithm=(
            "value" if use_fsdp else "norm"
        ),  # FSDP requires 'value', DDP can use 'norm'
        accumulate_grad_batches=grad_accumulation_steps,
        # Logging and checkpointing
        callbacks=callbacks,
        log_every_n_steps=report_every_n_steps,
        val_check_interval=val_check_interval,
        # Performance
        benchmark=True,
        deterministic=False,
        # Save model checkpoints via ModelCheckpoint (not activation/gradient checkpointing)
        enable_checkpointing=True,
        use_distributed_sampler=False,  # StreamingDataset handles distributed sampling
    )

    # Train the model (resume from checkpoint if provided)
    trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)

    # Print final results
    if trainer.global_rank == 0:
        if val_loader is not None:
            print(
                f"Final validation loss: {trainer.callback_metrics.get('val/loss', 0.0):.4f}"
            )
            print(
                f"Final validation perplexity: {trainer.callback_metrics.get('val/perplexity', 0.0):.4f}"
            )
        print(f"Checkpoints saved to: {checkpoint_dir}")

        return Dir.from_local_sync(output_dir)

    return None
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The trainer configuration brings together all the pieces we've discussed:

- **`precision="bf16-mixed"`**: BFloat16 mixed precision training. BF16 has the same dynamic range as FP32 (unlike FP16), so you don't need loss scaling. This is the standard choice for modern GPU training.
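- **`gradient_clip_val=1.0`**: Clips gradients to prevent them from exploding during training. With FSDP the algorithm is `"value"`, so each gradient element is clamped to [-1.0, 1.0]; with DDP it is `"norm"`, so the total gradient norm is capped at 1.0.
- **`accumulate_grad_batches`**: Accumulates gradients over several micro-batches (forward and backward passes) before each optimizer step. This lets us hit a larger effective batch size than what fits in GPU memory.
- **`val_check_interval`**: How often to run validation. For long training runs, validating once per epoch would be far too infrequent, so an integer value runs validation every `N` training batches instead. Lightning counts batches, not optimizer steps, so with gradient accumulation the interval in optimizer steps is `N / accumulate_grad_batches`.
- **`use_distributed_sampler=False`**: We disable Lightning's built-in distributed sampler because `StreamingDataset` handles data sharding internally. Using both would cause conflicts.
- **`benchmark=True`**: Enables cuDNN autotuning: the first time cuDNN sees a given input shape, it benchmarks several convolution algorithms and caches the fastest. The benefit is largest for convolution-heavy models; a transformer built mostly from matrix multiplications gains little, but with fixed-length sequences the one-time cost is negligible.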
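
The BF16 point in the first bullet can be checked with a small pure-Python sketch. It emulates the formats with `struct` only to compare their ranges (BF16 as the upper 16 bits of an FP32 value, truncated rather than rounded, and FP16 via the `"e"` format code); it is not how PyTorch implements mixed precision.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by keeping the top 16 bits of an IEEE float32.

    That keeps the sign bit and the full 8-bit exponent (FP32's range) but
    only 7 mantissa bits. Truncation is used instead of round-to-nearest.
    """
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    (value,) = struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))
    return value


def fits_in_float16(x: float) -> bool:
    """FP16 has a 5-bit exponent; magnitudes above ~65504 overflow."""
    try:
        struct.pack(">e", x)
        return True
    except OverflowError:
        return False


for x in [1.0, 65504.0, 1e6, 1e30]:
    print(f"{x:>8g}  bf16={to_bfloat16(x):.4g}  fits_fp16={fits_in_float16(x)}")
```

BF16 keeps 1e30 within a fraction of a percent, while FP16 cannot represent anything above 65504. That range headroom is why BF16 training can skip the loss scaling FP16 needs; the cost is coarser precision, visible in how much 65504.0 moves after truncation.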

The trainer then calls `fit()` with the model, data loaders, and optionally a checkpoint path to resume from.

### Tying it together

The pipeline task orchestrates everything:

```
@driver_env.task
async def distributed_llm_pipeline(
    model_size: str,
    dataset_name: str = "Salesforce/wikitext",
    dataset_config: str = "wikitext-103-raw-v1",
    max_length: int = 2048,
    max_train_samples: Optional[int] = 10000,
    max_val_samples: Optional[int] = 1000,
    max_steps: int = 100000,
    resume_checkpoint: Optional[Dir] = None,
    checkpoint_upload_steps: int = 1000,
    # Optional overrides (if None, uses model preset defaults)
    batch_size: Optional[int] = None,
    learning_rate: Optional[float] = None,
    use_fsdp: bool = True,
) -> Optional[Dir]:
    # Get model configuration
    model_config = get_model_config(model_size)

    # Use preset values if not overridden
    actual_batch_size = (
        batch_size if batch_size is not None else model_config["batch_size"]
    )
    actual_learning_rate = (
        learning_rate if learning_rate is not None else model_config["learning_rate"]
    )

    # Step 1: Load and prepare streaming dataset
    prepared_dataset = await load_and_prepare_streaming_dataset(
        dataset_name=dataset_name,
        dataset_config=dataset_config,
        max_length=max_length,
        train_split="train",
        val_split="validation",
        max_train_samples=max_train_samples,
        max_val_samples=max_val_samples,
        shard_size_mb=64,  # 64MB shards
    )

    # Step 2: Run distributed training
    if resume_checkpoint is not None:
        print("\nStep 2: Resuming distributed training from checkpoint...")
    else:
        print("\nStep 2: Starting distributed training from scratch...")

    target_global_batch = 256
    world_size = NUM_NODES * DEVICES_PER_NODE
    effective_per_step = world_size * actual_batch_size
    grad_accumulation_steps = max(
        1, math.ceil(target_global_batch / max(1, effective_per_step))
    )

    checkpoint_dir = train_distributed_llm(
        prepared_dataset=prepared_dataset,
        resume_checkpoint=resume_checkpoint,
        vocab_size=VOCAB_SIZE,
        n_positions=N_POSITIONS,
        n_embd=model_config["n_embd"],
        n_layer=model_config["n_layer"],
        n_head=model_config["n_head"],
        batch_size=actual_batch_size,
        num_workers=8,
        max_steps=max_steps,
        learning_rate=actual_learning_rate,
        weight_decay=0.1,
        warmup_steps=500,
        use_fsdp=use_fsdp,
        checkpoint_upload_steps=checkpoint_upload_steps,
        checkpoint_every_n_steps=model_config["checkpoint_every_n_steps"],
        report_every_n_steps=model_config["report_every_n_steps"],
        val_check_interval=model_config["val_check_interval"],
        grad_accumulation_steps=grad_accumulation_steps,
    )

    return checkpoint_dir
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

The flow is straightforward: load the configuration, prepare the data, and run training. Because `train_distributed_llm` consumes the `prepared_dataset` returned by the awaited data preparation call, training only starts after data preparation completes. If data preparation is cached from a previous run, Flyte returns the cached result and training starts almost immediately.

The gradient accumulation calculation is worth noting. We want a global batch size of 256 (this affects training dynamics), but each GPU can only fit a small batch. With 8 GPUs and batch size 1 each, we need 32 accumulation steps to hit 256.
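
The same calculation, pulled out into a standalone function, makes the effect of rounding up visible. This sketch assumes 8 GPUs per node as in the example above; `grad_accumulation_steps` is a hypothetical name for the inline logic in `distributed_llm_pipeline`.

```python
import math


def grad_accumulation_steps(
    target_global_batch: int, num_nodes: int, devices_per_node: int, per_device_batch: int
) -> int:
    # Same formula as the pipeline: round up so the global batch is never
    # smaller than the target, and never accumulate fewer than 1 step.
    world_size = num_nodes * devices_per_node
    effective_per_step = world_size * per_device_batch
    return max(1, math.ceil(target_global_batch / max(1, effective_per_step)))


for nodes, per_device in [(1, 1), (4, 1), (1, 3)]:
    steps = grad_accumulation_steps(256, nodes, 8, per_device)
    global_batch = nodes * 8 * per_device * steps
    print(f"nodes={nodes} per_device={per_device} -> steps={steps}, global batch={global_batch}")
```

Because the step count is rounded up, the global batch lands exactly on 256 when it divides evenly (32 steps on one node, 8 on four nodes) and overshoots slightly otherwise (11 steps × 24 = 264 with a per-device batch of 3).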

## Running the pipeline

With everything defined, running is simple:

```
if __name__ == "__main__":
    flyte.init_from_config()

    run = flyte.run(
        distributed_llm_pipeline,
        model_size="30B",
        dataset_name="cerebras/SlimPajama-627B",
        dataset_config=None,
        max_length=2048,
        max_train_samples=5_000_000,
        max_val_samples=50_000,
        max_steps=15_000,
        use_fsdp=True,
        checkpoint_upload_steps=1000,
    )

    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/pretraining/train.py*

This configuration is designed for testing and demonstration. Notice `max_train_samples=5_000_000`: 5 million samples from a dataset with 627 billion tokens is a tiny fraction, enough to verify that everything works without burning through compute.

For a real pretraining run, you would remove this limit by setting `max_train_samples=None`, or increase it significantly. You would also increase `max_steps` to match your compute budget, likely scale to multiple nodes by setting `NUM_NODES=4` or higher, and allocate more resources. The rest of the pipeline remains unchanged.

To launch the run, create a Flyte config that points at your endpoint, then run the script:

```bash
flyte create config --endpoint <FLYTE_OR_UNION_ENDPOINT> --project <PROJECT_NAME> --domain <DOMAIN_NAME> --builder remote
uv run train.py
```

When you run this, Flyte:

1. **Builds the container** (cached after first run)
2. **Schedules data prep** on CPU nodes
3. **Waits for data prep** (or skips if cached)
4. **Provisions H200 nodes** and launches distributed training
5. **Streams logs and metrics** to the UI in real time

Open the Flyte UI to observe the workflow execution. The data preparation task completes first, followed by the training task spinning up. As training begins, the Flyte Reports dashboard starts plotting loss curves. If anything goes wrong, the logs are immediately available in the UI.

![Training Log](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/distributed-llm-pretraining/logs.png)

If training fails due to an out-of-memory error, a GPU driver error, or a hardware issue, check the logs, fix the problem, and restart the run with `resume_checkpoint` pointing to the most recent checkpoint. Training resumes from where it left off. Flyte tracks the full execution history, so it is easy to see exactly what happened.

## Going further

If you've run through this tutorial, here's where to go next depending on what you're trying to do:

**You want to train on your own data.** The data prep task accepts any Hugging Face dataset with a `text` column. If your data isn't on Hugging Face, you can modify `load_and_prepare_streaming_dataset` to read from S3, local files, or any other source. The key is getting your data into MDS format; once it's there, streaming and sharding just work. For production training, look at SlimPajama, [RedPajama](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T), or [The Pile](https://huggingface.co/datasets/EleutherAI/pile) as starting points.

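**You want to scale to more GPUs.** Bump `NUM_NODES` and Flyte handles the rest. The main thing to watch is the effective batch size. The pipeline already recomputes `grad_accumulation_steps` from `NUM_NODES * DEVICES_PER_NODE`, so adding GPUs automatically reduces accumulation to keep the global batch near `target_global_batch` (256). Once the per-step batch alone exceeds that target, accumulation bottoms out at 1 and the global batch grows past 256. To experiment with larger batches, raise `target_global_batch`.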

**Your training keeps failing.** Add `retries=3` to your task decorator for automatic retry on transient failures. This handles spot instance preemption, temporary network issues, and the occasional GPU that decides to stop working. Combined with checkpointing, you get fault-tolerant training that can survive most infrastructure hiccups. For persistent failures, the Flyte UI logs are your friend as they capture stdout/stderr from all ranks.

**You want better visibility into what's happening.** We're actively working on surfacing GPU driver logs (xid/sxid errors), memory utilization breakdowns, and NCCL communication metrics directly in the Flyte UI. If you're hitting issues that the current logs don't explain, reach out. Your feedback helps us prioritize what observability features to build next!

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/qwen-vl-finetuning ===

# Fine-tuning a vision-language model with a frozen backbone

Large vision-language models like Qwen2.5-VL are remarkably capable out of the box. But adapting one to a specialized task raises an immediate question: do you really need to update 3 billion parameters?

Usually, no. The **frozen backbone pattern** is a practical alternative: keep all pretrained weights frozen and train only a small, task-specific adapter inserted before the vision encoder. The adapter learns to transform its input in a way that makes the frozen model perform well on your task without touching the underlying billions of parameters. The result is faster training, lower memory pressure, and a much smaller set of weights to store and version.

This tutorial makes that pattern concrete. We take a partially occluded image classification task (CIFAR-10 images with random black rectangles covering 22–45% of the frame) and train a tiny Conv2d adapter to "see through" the occlusion before the frozen VLM processes it. The adapter has approximately **10,500 trainable parameters**. The backbone has 3 billion.

The machine learning is interesting, but the real focus here is on shipping a production-grade training pipeline:

- **Multi-node distributed training** across 2 nodes × 4 GPUs using PyTorch Elastic and DeepSpeed Stage 2
- **Automatic fault tolerance**: checkpoints upload to object storage after every validation epoch; if training fails, the pipeline returns the last known-good checkpoint instead of crashing
- **Live observability**: a streaming HTML dashboard in the Flyte UI updates in real time as training runs, with no separate monitoring infrastructure required
- **Cached data preparation**: dataset processing runs once and is reused across all reruns
- **Clean task isolation**: each stage runs with exactly the resources it needs, nothing more

> [!NOTE]
> Full code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning).

## Overview

The pipeline has four tasks with clearly defined responsibilities:

1. **Dataset preparation** (`prepare_occlusion_dataset`): Downloads CIFAR-10, applies random occlusions, and writes image manifests as streaming JSONL files to object storage. Runs on CPU and is cached, so it only runs once regardless of how many times you rerun the pipeline with the same config.
2. **Multi-node training** (`train_qwen_adapter_multinode`): Runs PyTorch Lightning with DeepSpeed Stage 2 across 2 nodes × 4 L40S GPUs. Only the adapter trains; the 3B backbone stays frozen.
3. **Evaluation** (`evaluate_qwen_adapter`): Loads the saved adapter, runs inference on validation examples, and produces a predictions report. Runs on a single GPU.
4. **Driver** (`qwen_vl_multinode_deepspeed`): The pipeline entry point. Orchestrates the three tasks above, manages WandB initialization, handles recovery from training failures, and produces a final HTML report in the Flyte UI.

Why this separation? It mirrors how production pipelines should be structured. Data prep is cheap and deterministic so we cache it. Training is expensive and failure-prone so we isolate it with fault tolerance. Evaluation needs different hardware than training. The driver is pure coordination, so it gets minimal resources.

## Implementation

### Setting up the environment

Different tasks need different compute. Flyte's `TaskEnvironment` is how you declare exactly what each task needs.

First, define the container images. Training needs a full CUDA stack with ML libraries, driver compatibility, and DeepSpeed's build tools:

```
gpu_image = (
    flyte.Image.from_base("nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04")
    .clone(name="qwen_vl_multinode_deepspeed", python_version=(3, 13), extendable=True)
    .with_apt_packages("build-essential")
    .with_pip_packages(
        "torch==2.9.1",
        "torchvision==0.24.1",
        "lightning==2.6.1",
        "transformers==4.57.3",
        "deepspeed==0.18.8",
        "datasets==4.4.1",
        "pillow==11.3.0",
        "flyteplugins-pytorch>=2.0.11",
        "flyteplugins-jsonl>=2.0.11",
        "flyteplugins-wandb>=2.0.11",
    )
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/config.py*

`from_base` starts from the official NVIDIA CUDA development image, which ships the CUDA toolkit (including `nvcc`), cuDNN, and NCCL out of the box. `with_apt_packages("build-essential")` adds the C/C++ compiler toolchain, which DeepSpeed needs because it compiles some CUDA kernels at first use; without build tools, those kernels cannot be compiled and DeepSpeed cannot use them. The non-GPU image for data preparation and orchestration is much lighter:

```
non_gpu_image = flyte.Image.from_debian_base(
    name="qwen_vl_multinode_deepspeed_non_gpu"
).with_pip_packages(
    "flyteplugins-pytorch>=2.0.11",
    "flyteplugins-jsonl>=2.0.11",
    "flyteplugins-wandb>=2.0.11",
    "lightning==2.6.1",
    "datasets==4.4.1",
    "pillow==11.3.0",
    "torchvision==0.24.1",
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/config.py*

With images defined, each task gets its own resource declaration:

```
dataset_env = flyte.TaskEnvironment(
    name="qwen_vl_dataset_prep",
    image=non_gpu_image,
    resources=flyte.Resources(cpu=5, memory="15Gi"),
    cache="auto",
)

training_env = flyte.TaskEnvironment(
    name="qwen_vl_multinode_training",
    image=gpu_image,
    resources=flyte.Resources(
        cpu=42,
        memory="256Gi",
        gpu=f"L40s:{DEVICES_PER_NODE}",
        shm="16Gi",
    ),
    plugin_config=Elastic(nnodes=NUM_NODES, nproc_per_node=DEVICES_PER_NODE),
    secrets=[
        flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")
    ],  # TODO: update with your own secret key
    env_vars={
        "TORCH_DISTRIBUTED_DEBUG": "INFO",
        "NCCL_DEBUG": "WARN",
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_HOME": "/usr/local/cuda",
        "DS_SKIP_CUDA_CHECK": "1",
    },
)

evaluation_env = flyte.TaskEnvironment(
    name="qwen_vl_adapter_eval",
    image=gpu_image,
    resources=flyte.Resources(cpu=16, memory="64Gi", gpu="L40s:1"),
    cache="auto",
)

driver_env = flyte.TaskEnvironment(
    name="qwen_vl_multinode_driver",
    image=non_gpu_image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    depends_on=[dataset_env, training_env, evaluation_env],
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/config.py*

A few things worth noting here:

- **`Elastic(nnodes=2, nproc_per_node=4)`**: Flyte's integration with PyTorch's elastic launch. It handles process spawning (one process per GPU), rank assignment, and distributed environment setup (master address, world size, rendezvous) without any shell scripting or manual `torchrun` invocations.
- **`shm="16Gi"`**: Sizes the container's shared memory (`/dev/shm`). NCCL uses it for communication between GPUs on the same node, and `DataLoader` worker processes use it to hand batches to the training process. Container defaults are often far too small; without enough shared memory, you'll see cryptic errors from the communication library when training starts.
- **`cache="auto"`**: The dataset preparation and evaluation tasks are cached by input hash. Running the pipeline twice with the same hyperparameters skips dataset preparation entirely on the second run.
- **`depends_on`**: The driver task declares that each worker image must finish building before it starts, ensuring containers are ready before the driver begins orchestrating.
- **`secrets`**: The WandB API key is injected from Flyte's secret store as an environment variable. No credentials in code.

All training hyperparameters flow through a single typed dataclass:

```
from dataclasses import asdict, dataclass


@dataclass
class Config:
    model_name: str = DEFAULT_MODEL_NAME
    image_size: int = IMAGE_SIZE
    max_train_samples: int = 1024
    max_val_samples: int = 256
    epochs: int = 8
    per_device_batch_size: int = 1
    target_global_batch_size: int = 16
    learning_rate: float = 2e-4
    weight_decay: float = 1e-2
    reconstruction_loss_weight: float = 0.35
    report_every_n_steps: int = 10
    num_workers: int = 4
    max_length: int = 512
    eval_examples: int = 16
    train_occlusion_min: float = 0.22
    train_occlusion_max: float = 0.42
    eval_occlusion_min: float = 0.28
    eval_occlusion_max: float = 0.45
    seed: int = 7

    def to_dict(self) -> dict:
        return asdict(self)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/config.py*

Using a dataclass rather than scattered constants or argparse arguments means the full config is serializable, can be stored in artifact metadata alongside the model checkpoint, and flows cleanly as a typed input between tasks. The `to_dict()` method serializes it for WandB logging.
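
A minimal sketch of the property this relies on: a dataclass converts to a plain dictionary, survives a JSON round trip, and can be rebuilt into an equal typed object. `TrainConfig` is a trimmed, hypothetical version of `Config`; Flyte performs its own serialization when passing dataclasses between tasks, so this only illustrates why the dataclass is convenient.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class TrainConfig:
    # Hypothetical subset of the tutorial's Config fields, for illustration.
    learning_rate: float = 2e-4
    epochs: int = 8
    per_device_batch_size: int = 1
    target_global_batch_size: int = 16
    seed: int = 7

    def to_dict(self) -> dict:
        return asdict(self)


config = TrainConfig(learning_rate=1e-4)
payload = json.dumps(config.to_dict())         # plain JSON, e.g. for run metadata
restored = TrainConfig(**json.loads(payload))  # rebuild the typed object
print(payload)
print(restored == config)  # True
```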

### Preparing the dataset

The dataset task handles everything: downloading CIFAR-10, generating occlusions, and writing the manifests.

```
@dataset_env.task
async def prepare_occlusion_dataset(config: Config) -> DatasetArtifacts:
    from PIL import Image
    from torchvision.datasets import CIFAR10
    from flyte.io import Dir
    from flyteplugins.jsonl import JsonlFile
    import random

    rng = random.Random(config.seed)

    images_dir = Path("/tmp/qwen_vl_occlusion_images")
    train_images_dir = images_dir / "train" / "images"
    val_images_dir = images_dir / "validation" / "images"
    train_images_dir.mkdir(parents=True, exist_ok=True)
    val_images_dir.mkdir(parents=True, exist_ok=True)

    prompt = (
        "The image may be partially occluded. "
        "Answer with exactly one CIFAR-10 class label: "
        + ", ".join(CLASS_NAMES)
        + ". What is the main object?"
    )

    async def export_split(
        dataset,
        split_name: str,
        limit: int,
        local_image_dir: Path,
        occ_min: float,
        occ_max: float,
    ):
        out = JsonlFile.new_remote(f"{split_name}_manifest.jsonl")
        async with out.writer() as writer:
            for idx in range(limit):
                pil_image, label_idx = dataset[idx]
                resized = pil_image.resize(
                    (config.image_size, config.image_size),
                    resample=Image.Resampling.BICUBIC,
                )
                rel_path = f"{split_name}/images/{split_name}-{idx:05d}.png"
                resized.save(local_image_dir / f"{split_name}-{idx:05d}.png")
                occlusion = build_occlusion_box(
                    width=config.image_size,
                    height=config.image_size,
                    rng=rng,
                    min_fraction=occ_min,
                    max_fraction=occ_max,
                )
                await writer.write(
                    {
                        "image_path": rel_path,
                        "label": CLASS_NAMES[label_idx],
                        "label_index": int(label_idx),
                        "prompt": prompt,
                        "occlusion": occlusion,
                    }
                )
        return out

    train_dataset = CIFAR10(root="/tmp/cifar10", train=True, download=True)
    val_dataset = CIFAR10(root="/tmp/cifar10", train=False, download=True)

    train_manifest = await export_split(
        train_dataset,
        "train",
        config.max_train_samples,
        train_images_dir,
        config.train_occlusion_min,
        config.train_occlusion_max,
    )
    val_manifest = await export_split(
        val_dataset,
        "validation",
        config.max_val_samples,
        val_images_dir,
        config.eval_occlusion_min,
        config.eval_occlusion_max,
    )

    return DatasetArtifacts(
        train_manifest=train_manifest,
        val_manifest=val_manifest,
        images=await Dir.from_local(str(images_dir)),
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/data.py*

Each record is assigned a randomly placed black rectangle. The occlusion covers 22–42% of the image area during training and 28–45% during evaluation; the evaluation range is deliberately harder, to test how robust the adapter is. The images themselves are saved unoccluded. Instead, the bounding box coordinates are written into each manifest record alongside the image path and ground-truth label, so the training task can apply the occlusion and reconstruct the binary occlusion mask as the adapter's fourth input channel.
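
The tutorial doesn't show `build_occlusion_box` here, so the following is a hypothetical stand-in that matches the description above: it samples a rectangle whose area falls in `[min_fraction, max_fraction]` of the image and builds the matching binary mask. The box keys (`x0`, `y0`, `x1`, `y1`), the aspect-ratio range, and the 224-pixel image size are assumptions, and a real pipeline would build the mask as a tensor rather than nested lists.

```python
import math
import random


def build_occlusion_box(
    width: int, height: int, rng: random.Random, min_fraction: float, max_fraction: float
) -> dict[str, int]:
    """Hypothetical helper: sample a rectangle covering a random fraction
    (between min_fraction and max_fraction) of the image area."""
    target_area = rng.uniform(min_fraction, max_fraction) * width * height
    aspect = rng.uniform(0.5, 2.0)  # assumed range for box width / height
    box_w = min(width, max(1, round(math.sqrt(target_area * aspect))))
    box_h = min(height, max(1, round(target_area / box_w)))
    x0 = rng.randint(0, width - box_w)  # inclusive bounds keep the box inside
    y0 = rng.randint(0, height - box_h)
    return {"x0": x0, "y0": y0, "x1": x0 + box_w, "y1": y0 + box_h}


def occlusion_mask(box: dict[str, int], width: int, height: int) -> list[list[float]]:
    """Binary mask (1.0 inside the box), the adapter's fourth input channel."""
    return [
        [1.0 if box["x0"] <= x < box["x1"] and box["y0"] <= y < box["y1"] else 0.0
         for x in range(width)]
        for y in range(height)
    ]


rng = random.Random(7)  # fixed seed, as with Config.seed
box = build_occlusion_box(224, 224, rng, min_fraction=0.22, max_fraction=0.42)
mask = occlusion_mask(box, 224, 224)
covered = sum(map(sum, mask)) / (224 * 224)
print(box, f"covers {covered:.1%}")
```

Seeding a dedicated `random.Random` instance, as the task does with `config.seed`, makes the occlusions reproducible across reruns without touching global random state.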

Two Flyte primitives handle data persistence without any manual storage management:

- **`JsonlFile.new_remote()`** creates a new JSONL file in remote object storage, and its `writer()` streams records directly to it. The training task reads records back via the file's `iter_records_sync()` method, so there are no local file paths or S3 credentials to manage.
- **`Dir.from_local()`** uploads the local images directory to object storage and returns a typed handle. The training task downloads it to a local path via `Dir.download_sync()`.

Because `cache="auto"` is set on this task, dataset preparation runs once. Subsequent reruns with the same config skip it entirely.

### The adapter

Here's the entire trainable component of the model, with about 10,500 parameters:

```
import torch
from torch import nn


class ResidualOcclusionAdapter(nn.Module):
    def __init__(self, hidden_channels: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(4, hidden_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, 3, kernel_size=1),
            nn.Tanh(),
        )
        self.gate = nn.Parameter(torch.tensor(0.10))

    def forward(
        self, pixel_values: torch.Tensor, occlusion_mask: torch.Tensor
    ) -> torch.Tensor:
        if pixel_values.ndim != 4:
            raise ValueError(
                "ResidualOcclusionAdapter expects dense image tensors with shape "
                f"(B, C, H, W), but received {tuple(pixel_values.shape)}."
            )
        if occlusion_mask.ndim == 3:
            occlusion_mask = occlusion_mask.unsqueeze(1)
        adapter_input = torch.cat(
            [pixel_values, occlusion_mask.to(pixel_values.dtype)],
            dim=1,
        )
        residual = self.net(adapter_input)
        return pixel_values + torch.tanh(self.gate) * residual
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/model.py*

The adapter takes the occluded image (3 channels) concatenated with the binary occlusion mask (1 channel) as a 4-channel input. It predicts a residual correction through a small convolutional network, then adds that correction back to the original pixels. The learnable `gate` scalar, initialized to `0.10`, controls how strongly the adapter modifies the image. Because the correction is scaled by `tanh(gate)` and the network's final `Tanh` bounds the residual to (-1, 1), the adapter starts close to an identity transformation; the gate can grow during training if stronger corrections reduce the loss.
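
You can verify the ~10,500 figure by hand: each `Conv2d` has `out × in × k × k` weights plus one bias per output channel, and the gate adds one scalar. The sketch below does that arithmetic in plain Python, using the layer shapes from `ResidualOcclusionAdapter` with the default `hidden_channels=32`; with PyTorch installed, `sum(p.numel() for p in ResidualOcclusionAdapter().parameters())` should report the same total.

```python
import math


def conv2d_params(in_ch: int, out_ch: int, kernel: int) -> int:
    # weights (out_ch * in_ch * k * k) plus one bias per output channel
    return out_ch * in_ch * kernel * kernel + out_ch


hidden = 32
layers = [
    conv2d_params(4, hidden, 3),       # RGB + mask -> hidden
    conv2d_params(hidden, hidden, 3),  # hidden -> hidden
    conv2d_params(hidden, 3, 1),       # 1x1 projection back to RGB
]
gate = 1  # the scalar gate parameter (GELU and Tanh have no parameters)
total = sum(layers) + gate
print(layers, total)              # [1184, 9248, 99] 10532
print(round(math.tanh(0.10), 4))  # 0.0997
```

The last line is the initial residual scale: with `gate=0.10`, the bounded correction is multiplied by about 0.0997, so at initialization the adapter changes each pixel value by less than 0.1.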

The adapter is plugged into Qwen2.5-VL via a Lightning module:

```
class QwenVLAdapterModule(L.LightningModule):
    def __init__(
        self,
        model_name: str,
        learning_rate: float,
        weight_decay: float,
        reconstruction_loss_weight: float,
    ):
        super().__init__()
        from transformers import Qwen2_5_VLForConditionalGeneration

        self.save_hyperparameters()
        self.adapter = ResidualOcclusionAdapter()

        self.backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
        )
        self.backbone.requires_grad_(False)
        self.backbone.gradient_checkpointing_enable()

        # DeepSpeed checkpoints only persist the trainable adapter weights when
        # `exclude_frozen_parameters=True`. On resume we rebuild the frozen
        # backbone from Hugging Face and load the checkpoint non-strictly.
        self.strict_loading = False

        self.total_params, self.trainable_params = count_parameters(self)
        self.example_input_array = None
        self.vision_patch_size = int(self.backbone.config.vision_config.patch_size)
        self.temporal_patch_size = int(
            getattr(self.backbone.config.vision_config, "temporal_patch_size", 1)
        )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/model.py*

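The key line is `self.backbone.requires_grad_(False)`. It freezes all 3 billion backbone parameters, so only the adapter's ~10,500 weights receive gradients. Freezing doesn't skip the backward pass through the backbone, though. The adapter sits in front of the backbone, so its gradient has to flow back through every frozen layer.

`gradient_checkpointing_enable()` trades compute for memory. Instead of keeping all of the backbone's intermediate activations from the forward pass, it stores a subset and recomputes the rest during the backward pass. The optimizer state covers only the adapter, so in this setup the backbone's weights and activations are what fill GPU memory, and this saving matters.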
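
The sketch below shows the trade-off in miniature. It uses a hypothetical four-layer scalar "backbone" made of `tanh(w * h)` layers with fixed weights. It compares two ways of computing the gradient that reaches an input-side adapter scale:

* a standard backward pass, which stores each layer's input;
* a checkpointed backward pass, which stores only the original input and recomputes the rest.

This is only an illustration. Hugging Face's `gradient_checkpointing_enable()` works on real tensors and checkpoints at transformer-block boundaries.

```python
import math

# Hypothetical frozen "backbone": four scalar layers h -> tanh(w * h) with fixed weights.
FROZEN_WEIGHTS = [0.9, -1.3, 0.7, 1.1]

def layer(w, h):
    return math.tanh(w * h)

def layer_grad(w, h):
    # d/dh tanh(w * h) = w * (1 - tanh(w * h) ** 2)
    t = math.tanh(w * h)
    return w * (1.0 - t * t)

def forward(x):
    for w in FROZEN_WEIGHTS:
        x = layer(w, x)
    return x

def backward_stored(x):
    """Standard backward: keep every layer's input from the forward pass."""
    saved, h = [], x
    for w in FROZEN_WEIGHTS:
        saved.append(h)                      # one stored activation per layer
        h = layer(w, h)
    grad = 1.0
    for w, a in zip(reversed(FROZEN_WEIGHTS), reversed(saved)):
        grad *= layer_grad(w, a)             # chain rule, last layer first
    return h, grad, len(saved)

def backward_checkpointed(x):
    """Checkpointed backward: keep only x and recompute each layer's input when needed."""
    out = forward(x)
    grad = 1.0
    for i in reversed(range(len(FROZEN_WEIGHTS))):
        a = x
        for w in FROZEN_WEIGHTS[:i]:         # recompute the input to layer i
            a = layer(w, a)
        grad *= layer_grad(FROZEN_WEIGHTS[i], a)
    return out, grad, 1                      # only x was kept

# Input-side adapter: x = pixel + scale * residual. The frozen weights never change,
# but d(out)/d(scale) = d(out)/d(x) * residual still needs the full backward pass.
pixel, residual, scale = 0.3, 0.5, 0.1
x = pixel + scale * residual
out_a, grad_a, kept_a = backward_stored(x)
out_b, grad_b, kept_b = backward_checkpointed(x)
adapter_grad = grad_b * residual
print(kept_a, kept_b, math.isclose(grad_a, grad_b), adapter_grad)
```

This simplest form recomputes from the single stored input for every layer, so its recomputation cost grows quadratically with depth. Practical implementations store activations at block boundaries and recompute only within each block.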

`strict_loading = False` handles an important DeepSpeed checkpoint detail. When `exclude_frozen_parameters=True` is set on the strategy, DeepSpeed saves only the adapter weights in checkpoints, not the 3B frozen backbone. On resume, the checkpoint contains no backbone weights, so loading must be non-strict. The `on_load_checkpoint` hook fills in the missing backbone weights from the freshly loaded Hugging Face model. The result is small checkpoints and a fully initialized model.
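
Non-strict loading boils down to three rules:

1. Copy the keys the checkpoint has.
2. Keep the freshly loaded values for everything else.
3. Report what didn't match instead of raising an error.

PyTorch's `Module.load_state_dict(strict=False)` behaves this way and returns the missing and unexpected keys. The plain-dictionary sketch below mimics that behavior. The helper name and the key names are hypothetical.

```python
def load_non_strict(model_state, checkpoint_state):
    """Merge checkpoint_state into model_state, like load_state_dict(strict=False)."""
    merged = dict(model_state)
    missing = sorted(k for k in model_state if k not in checkpoint_state)
    unexpected = sorted(k for k in checkpoint_state if k not in model_state)
    for key, value in checkpoint_state.items():
        if key in model_state:
            merged[key] = value              # overwrite only keys the model knows about
    return merged, missing, unexpected

# Freshly built module: backbone weights come from Hugging Face, adapter is newly initialized.
model_state = {
    "backbone.layer0.weight": "pretrained",
    "backbone.layer1.weight": "pretrained",
    "adapter.conv.weight": "random-init",
    "adapter.gate": 0.10,
}
# Checkpoint saved with frozen parameters excluded: only adapter entries are present.
checkpoint_state = {"adapter.conv.weight": "trained", "adapter.gate": 0.42}

merged, missing, unexpected = load_non_strict(model_state, checkpoint_state)
print(missing)      # the backbone keys, which keep their pretrained values
print(unexpected)   # nothing unexpected in an adapter-only checkpoint
```

With strict loading, those missing backbone keys would raise an error instead of being reported.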

The training loss combines two objectives:

```
    def _forward_losses(
        self, batch: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        backbone_dtype = next(self.backbone.parameters()).dtype
        if batch["pixel_values"].ndim == 2:
            if "image_grid_thw" not in batch:
                raise ValueError(
                    "Packed Qwen pixel values require `image_grid_thw` to reconstruct "
                    "dense images for the Conv2d adapter."
                )
            grid_thw = batch["image_grid_thw"]
            dense_pixels = packed_pixels_to_dense_images(
                batch["pixel_values"].to(dtype=backbone_dtype),
                grid_thw,
                patch_size=self.vision_patch_size,
                temporal_patch_size=self.temporal_patch_size,
            )
            clean_pixels = packed_pixels_to_dense_images(
                batch["clean_pixel_values"].to(dtype=backbone_dtype),
                grid_thw,
                patch_size=self.vision_patch_size,
                temporal_patch_size=self.temporal_patch_size,
            )
            adapted_dense = self.adapter(dense_pixels, batch["occlusion_mask"])
            adapted_pixels = dense_images_to_packed_pixels(
                adapted_dense,
                grid_thw,
                patch_size=self.vision_patch_size,
                temporal_patch_size=self.temporal_patch_size,
            )
        else:
            clean_pixels = batch["clean_pixel_values"].to(dtype=backbone_dtype)
            adapted_dense = self.adapter(
                batch["pixel_values"].to(dtype=backbone_dtype),
                batch["occlusion_mask"],
            )
            adapted_pixels = adapted_dense

        forward_kwargs = {
            "input_ids": batch["input_ids"],
            "attention_mask": batch["attention_mask"],
            "pixel_values": adapted_pixels,
            "labels": batch["labels"],
        }
        if "image_grid_thw" in batch:
            forward_kwargs["image_grid_thw"] = batch["image_grid_thw"]
        outputs = self.backbone(**forward_kwargs)

        clean_pixels = clean_pixels.to(
            device=adapted_pixels.device, dtype=backbone_dtype
        )
        occlusion_mask = batch["occlusion_mask"].to(
            device=adapted_pixels.device,
            dtype=backbone_dtype,
        )
        if occlusion_mask.ndim == 3:
            occlusion_mask = occlusion_mask.unsqueeze(1)
        if occlusion_mask.shape[-2:] != adapted_dense.shape[-2:]:
            occlusion_mask = F.interpolate(
                occlusion_mask,
                size=adapted_dense.shape[-2:],
                mode="nearest",
            )

        reconstruction_error = (adapted_dense - clean_pixels).abs() * occlusion_mask
        mask_denominator = (occlusion_mask.sum() * adapted_dense.shape[1]).clamp_min(
            1.0
        )

        reconstruction_loss = reconstruction_error.sum() / mask_denominator
        total_loss = (
            outputs.loss + self.hparams.reconstruction_loss_weight * reconstruction_loss
        )

        return {
            "total_loss": total_loss,
            "lm_loss": outputs.loss,
            "reconstruction_loss": reconstruction_loss,
        }
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/model.py*

The **language modeling loss** (cross-entropy on the predicted class label tokens) drives the model to produce correct answers. The **reconstruction loss** (mean absolute error between the adapter's output and the clean image, computed only in the occluded region) pushes the adapter to actually restore the missing pixels rather than finding a representation shortcut. Without it, the adapter could overfit the frozen backbone's quirks and produce correct tokens while generating noise in the masked region. The `reconstruction_loss_weight` (default `0.35`) balances these two objectives.
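
The normalization in the reconstruction term is easy to misread, so here is a pure-Python version for a single image stored as nested lists `[channel][row][col]`. The absolute error is summed over masked pixels in all channels. The sum is divided by the masked pixel count times the channel count, clamped to at least 1 so an unoccluded image doesn't divide by zero. The tiny 2-channel image and the `lm_loss` value are hypothetical.

```python
def masked_l1(adapted, clean, mask):
    """Mean absolute error over occluded pixels. adapted/clean: [C][H][W]; mask: [H][W] of 0/1."""
    channels = len(adapted)
    error_sum = 0.0
    for c in range(channels):
        for y, row in enumerate(mask):
            for x, m in enumerate(row):
                error_sum += abs(adapted[c][y][x] - clean[c][y][x]) * m
    # Same denominator as the tutorial: masked pixels * channels, clamped to >= 1.
    denominator = max(sum(sum(row) for row in mask) * channels, 1.0)
    return error_sum / denominator

def total_loss(lm_loss, reconstruction_loss, weight=0.35):
    return lm_loss + weight * reconstruction_loss

clean = [[[0.5, 0.5], [0.5, 0.5]], [[1.0, 1.0], [1.0, 1.0]]]
adapted = [[[0.5, 0.9], [0.5, 0.5]], [[1.0, 0.0], [1.0, 1.0]]]
mask = [[0, 1], [0, 0]]                      # only the top-right pixel is occluded
recon = masked_l1(adapted, clean, mask)      # (0.4 + 1.0) / (1 pixel * 2 channels), about 0.7
print(recon, total_loss(lm_loss=1.2, reconstruction_loss=recon))
```

Errors outside the mask are multiplied by zero. In this example the adapter is not penalized for anything it does to unoccluded pixels; only the language modeling loss sees those pixels.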

Because Qwen2.5-VL's preprocessor packs image patches into a flat `(num_patches, patch_dim)` tensor, the adapter must unpack this into a spatial `(B, C, H, W)` tensor, apply the convolutions, then repack. The `packed_pixels_to_dense_images` and `dense_images_to_packed_pixels` utilities in `model.py` handle this format conversion transparently.
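
The sketch below shows the general idea of the conversion:

* **Packing** turns a dense `[C][H][W]` image into a flat list of patches, one row of `C * p * p` values per patch.
* **Unpacking** scatters those rows back into the dense image.

This is a simplified, hypothetical version. Qwen2.5-VL's real layout also folds in `temporal_patch_size` and orders patches for its spatial-merge step, which is why the tutorial's helpers take `image_grid_thw`.

```python
def patchify(image, p):
    """Pack a dense image [C][H][W] into a list of flat patches of length C * p * p."""
    c, h, w = len(image), len(image[0]), len(image[0][0])
    if h % p or w % p:
        raise ValueError("height and width must be divisible by the patch size")
    patches = []
    for gy in range(h // p):                 # patch-grid row
        for gx in range(w // p):             # patch-grid column
            patches.append([
                image[ch][gy * p + dy][gx * p + dx]
                for ch in range(c)
                for dy in range(p)
                for dx in range(p)
            ])
    return patches

def unpatchify(patches, c, h, w, p):
    """Inverse of patchify: scatter each flat patch back into a dense [C][H][W] image."""
    image = [[[0.0] * w for _ in range(h)] for _ in range(c)]
    grid_w = w // p
    for index, patch in enumerate(patches):
        gy, gx = divmod(index, grid_w)
        values = iter(patch)
        for ch in range(c):
            for dy in range(p):
                for dx in range(p):
                    image[ch][gy * p + dy][gx * p + dx] = next(values)
    return image

C, H, W, P = 3, 4, 4, 2
image = [[[float(ch * 100 + y * 10 + x) for x in range(W)] for y in range(H)] for ch in range(C)]
packed = patchify(image, P)
print(len(packed), len(packed[0]))  # 4 patches, each with 3 * 2 * 2 = 12 values
```

A Conv2d adapter needs the dense form because convolutions rely on each pixel's spatial neighbors, which are scattered across rows in the packed form.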

### Multi-node training with DeepSpeed

The training task is a standard PyTorch Lightning training loop with distributed infrastructure handled by Flyte and DeepSpeed:

```
@wandb_init
@training_env.task(report=True)
def train_qwen_adapter_multinode(
    train_manifest: JsonlFile,
    val_manifest: JsonlFile,
    images_dir: Dir,
    config: Config,
    resume_from: Optional[Dir] = None,
    recovery_uri: Optional[str] = None,
) -> Optional[Dir]:
    ...  # body omitted here; see tasks.py
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

The `@wandb_init` decorator integrates with the `wandb_config` context created in the driver task. It retrieves the initialized WandB run and attaches a `WandbLogger` to the trainer. The `report=True` flag on the task decorator enables Flyte Reports for live dashboard streaming from this task.

![Live Training](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/live_training_graph.png)
![Live Training Contd](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/losses.png)

DeepSpeed ZeRO Stage 2 shards optimizer states and gradients across GPUs, so each GPU holds only its slice of the trainable state. The critical configuration flag here is `exclude_frozen_parameters=True`:

```
    strategy = DeepSpeedStrategy(
        stage=2,
        offload_optimizer=False,
        offload_parameters=False,
        process_group_backend="nccl",
        exclude_frozen_parameters=True,
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

Without `exclude_frozen_parameters=True`, every checkpoint would also include the full frozen backbone weights, producing multi-gigabyte files and slow checkpoint saves. With it, checkpoints contain only the adapter's trainable state. The backbone is loaded independently on each worker from Hugging Face.
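
A back-of-the-envelope comparison shows why excluding frozen parameters matters. The sketch makes two assumptions:

* the backbone is stored in bf16, at 2 bytes per parameter;
* each adapter parameter carries an fp32 weight plus two fp32 AdamW moment buffers, at 12 bytes per parameter.

The actual optimizer and on-disk format may differ, and DeepSpeed adds some metadata on top.

```python
BACKBONE_PARAMS = 3_000_000_000   # "3 billion" frozen parameters, as described above
ADAPTER_PARAMS = 10_500           # approximate trainable adapter size
BF16_BYTES = 2
ADAPTER_BYTES_PER_PARAM = 4 * 3   # assumption: fp32 weight + two fp32 Adam moments

def human(num_bytes):
    """Format a byte count with decimal (1000-based) units."""
    for unit in ("B", "KB", "MB", "GB"):
        if num_bytes < 1000:
            return f"{num_bytes:.1f} {unit}"
        num_bytes /= 1000
    return f"{num_bytes:.1f} TB"

backbone_bytes = BACKBONE_PARAMS * BF16_BYTES
adapter_bytes = ADAPTER_PARAMS * ADAPTER_BYTES_PER_PARAM
print("frozen backbone weights:", human(backbone_bytes))
print("adapter training state:", human(adapter_bytes))
```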

Gradient accumulation is computed automatically to hit the target global batch size regardless of how many GPUs are actually running:

```
    world_size = NUM_NODES * DEVICES_PER_NODE
    per_step_batch = world_size * config.per_device_batch_size
    grad_accum_steps = max(
        1,
        math.ceil(config.target_global_batch_size / max(1, per_step_batch)),
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

With 2 nodes × 4 GPUs × per-device batch size 1, the effective per-step batch is 8. To reach the default target of 16, the trainer accumulates over 2 steps. Change `NUM_NODES` or `per_device_batch_size` and the calculation adjusts automatically.
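
Wrapping the same calculation in a standalone function makes its rounding behavior easy to check. Because it rounds up with `math.ceil`, the effective global batch can overshoot the target when the target isn't a multiple of the per-step batch. The 3-node and 4-node cases below are hypothetical.

```python
import math

def grad_accum_steps(num_nodes, devices_per_node, per_device_batch_size, target_global_batch_size):
    """Return (accumulation steps, effective global batch) using the tutorial's formula."""
    world_size = num_nodes * devices_per_node
    per_step_batch = world_size * per_device_batch_size
    steps = max(1, math.ceil(target_global_batch_size / max(1, per_step_batch)))
    return steps, per_step_batch * steps

print(grad_accum_steps(2, 4, 1, 16))  # tutorial default: (2, 16)
print(grad_accum_steps(3, 4, 1, 16))  # 12 per step -> 2 steps -> effective batch 24
print(grad_accum_steps(4, 4, 1, 16))  # 16 per step -> no accumulation: (1, 16)
```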

The trainer brings everything together:

```
    trainer = L.Trainer(
        accelerator="gpu",
        devices=DEVICES_PER_NODE,
        num_nodes=NUM_NODES,
        strategy=strategy,
        logger=wandb_logger,
        precision="bf16-mixed",
        max_epochs=config.epochs,
        accumulate_grad_batches=grad_accum_steps,
        callbacks=[
            checkpoint_callback,
            metrics_callback,
            recovery_callback,
            live_report_callback,
        ],
        gradient_clip_val=1.0,
        benchmark=True,
        log_every_n_steps=1,
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

`precision="bf16-mixed"` uses BFloat16, which has the same 8-bit exponent as FP32 and therefore the same dynamic range (unlike FP16), so you don't need loss scaling. It is a common default for modern VLM training.

`benchmark=True` enables cuDNN autotuning. The first time cuDNN sees each input shape, it benchmarks candidate kernels and then reuses the fastest one. This pays off when input sizes stay fixed from batch to batch.
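
The dynamic-range difference is easy to see with the standard library. Python's `struct` module can round a value to IEEE half precision (format `'e'`). bf16 can be emulated by keeping the top 16 bits of a float32. The `to_bf16` helper is a hypothetical sketch using round-to-nearest-even, not PyTorch's implementation.

```python
import struct

def to_fp16(x):
    """Round a Python float to IEEE half precision (1 sign, 5 exponent, 10 mantissa bits)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x):
    """Emulate bfloat16 by rounding a float32 to its top 16 bits (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)      # rounding bias for the 16 dropped bits
    bits = (bits >> 16) << 16                # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

tiny_grad = 1e-8   # a small gradient, below FP16's smallest subnormal (about 6e-8)
print("fp16:", to_fp16(tiny_grad))   # underflows to 0.0
print("bf16:", to_bf16(tiny_grad))   # stays close to 1e-8
```

Large values fail in the other direction. `struct.pack("<e", 1e5)` raises `OverflowError` because FP16 tops out at 65504, while bf16 shares FP32's range of roughly 3.4e38. Loss scaling exists to push FP16 gradients away from the underflow end. With bf16 that shift isn't needed.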

### Fault tolerance and recovery

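Multi-node GPU jobs fail. Hardware faults, spot instance preemptions, NCCL timeouts, and memory spikes make failure a question of when, not if. This pipeline handles it with a two-part system:

1. Training uploads recovery checkpoints as it goes.
2. The driver falls back to the latest checkpoint when training fails.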

After every validation epoch, the `RecoveryArtifactCallback` calls `trainer.save_checkpoint()` to write a DeepSpeed checkpoint directory, then uploads all shard files to the recovery URI. Each node's local rank 0 uploads its own shards; global rank 0 uploads the metadata files (`metrics.json`, `summary.json`). A distributed barrier between save and upload ensures all workers finish before training continues.
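
The division of upload work across ranks can be expressed as a small pure function. This is a hypothetical sketch of the rule described above, not the callback's code. It assumes ranks are laid out node by node, so `global_rank = node * devices_per_node + local_rank`. Giving each node its own uploader fits a setup where every node writes its shards to local disk.

```python
def upload_roles(global_rank, local_rank):
    """Return which uploads a given process is responsible for."""
    roles = []
    if local_rank == 0:
        roles.append("node shards")   # one uploader per node for that node's shard files
    if global_rank == 0:
        roles.append("metadata")      # metrics.json / summary.json uploaded exactly once
    return roles

def plan(num_nodes, devices_per_node):
    return {
        node * devices_per_node + local: upload_roles(node * devices_per_node + local, local)
        for node in range(num_nodes)
        for local in range(devices_per_node)
    }

for rank, roles in plan(num_nodes=2, devices_per_node=4).items():
    if roles:
        print(rank, roles)
```

On the tutorial's 2 × 4 layout, this gives two shard uploaders (global ranks 0 and 4) and a single metadata uploader (global rank 0).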

If training fails, the driver task catches the error and returns the last recovery artifact instead of propagating the failure:

```
    try:
        with wandb_config(
            project=wandb_project,
            entity=wandb_entity,
        ):
            training_artifacts = train_qwen_adapter_multinode(
                train_manifest=train_manifest,
                val_manifest=val_manifest,
                images_dir=images,
                config=config,
                resume_from=resume_training_artifacts,
                recovery_uri=recovery_uri,
            )
    except flyte.errors.RuntimeUserError as e:
        if recovery_uri is None:
            raise e
        print(f"Training failed - recovering latest checkpoint bundle: {recovery_uri}")
        try:
            recovered_artifacts = Dir(path=recovery_uri)
            recovered_root = await download_dir_async(recovered_artifacts)
            flyte.report.log(
                build_qwen_adapter_report_html(recovered_root, None),
                do_flush=True,
            )
            return recovered_artifacts
        except Exception:
            raise e
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

A failed run still produces useful output: the most recent checkpoint uploaded before the failure, along with a partial training report. To resume from that point, pass the recovery artifact as `resume_training_artifacts` on the next run. The training task downloads it, finds the most recent `.ckpt` file, and passes it to `trainer.fit()` as `ckpt_path`. Training picks up at the last saved epoch with optimizer state and metrics history intact.
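
Finding "the most recent `.ckpt`" can be as simple as comparing modification times. The hypothetical helper below searches a downloaded recovery directory recursively. It matches both files and directories, since a DeepSpeed checkpoint is written as a directory. The source may pick the checkpoint differently, for example by the epoch number in the name. The checkpoint names here are made up.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional

def find_latest_checkpoint(root) -> Optional[Path]:
    """Return the most recently modified *.ckpt path under root, or None if there is none."""
    candidates = list(Path(root).rglob("*.ckpt"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)

with tempfile.TemporaryDirectory() as tmp:
    older = Path(tmp, "epoch=0-step=32.ckpt")
    newer = Path(tmp, "epoch=1-step=64.ckpt")
    older.mkdir()
    newer.mkdir()
    os.utime(older, (1_000, 1_000))   # pin modification times so the example is deterministic
    os.utime(newer, (2_000, 2_000))
    latest_name = find_latest_checkpoint(tmp).name
    print(latest_name)

with tempfile.TemporaryDirectory() as empty:
    empty_result = find_latest_checkpoint(empty)   # nothing to resume from
```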

The recovery URI is constructed from the configurable base path and the run name:

```
s3://your-bucket/qwen-vl-multinode-deepspeed/<run-name>/qwen_vl_training_recovery/
```

This means each run gets its own recovery location, so you can identify exactly which run a checkpoint came from.
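
A hypothetical helper for that construction might look like the following. It normalizes the trailing slash on the base path and returns `None` when no base URI is configured. That is the case where the recovery handler re-raises the training error instead of recovering. How the driver obtains the current run name at runtime isn't shown here, so the run name below is made up.

```python
from typing import Optional

def build_recovery_uri(
    checkpoint_base_uri: Optional[str],
    run_name: str,
    leaf: str = "qwen_vl_training_recovery",
) -> Optional[str]:
    """Join base URI, run name, and recovery folder; None disables recovery."""
    if checkpoint_base_uri is None:
        return None
    base = checkpoint_base_uri.rstrip("/")
    return f"{base}/{run_name}/{leaf}/"

print(build_recovery_uri("s3://your-bucket/qwen-vl-multinode-deepspeed/", "my-run-123"))
# s3://your-bucket/qwen-vl-multinode-deepspeed/my-run-123/qwen_vl_training_recovery/
```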

### Live observability

`flyte.report` lets you push HTML content directly into the Flyte UI during task execution, with no separate monitoring infrastructure. The `LiveTrainingReportCallback` uses this to stream training metrics in real time:

```
    def _push_update(
        self,
        *,
        trainer,
        pl_module,
        status: str,
        phase: str,
        train_total=None,
        train_lm=None,
        train_recon=None,
        val_total=None,
        note: str,
    ) -> None:
        adapter_gate = float(torch.tanh(pl_module.adapter.gate).detach().cpu())

        def fmt(value):
            return f"{float(value):.4f}" if value is not None else "-"

        payload = {
            "step": trainer.global_step,
            "phase": phase,
            "train_total": fmt(train_total),
            "train_lm": fmt(train_lm),
            "train_recon": fmt(train_recon),
            "val_total": fmt(val_total),
            "train_total_value": (
                float(train_total) if train_total is not None else None
            ),
            "val_total_value": float(val_total) if val_total is not None else None,
            "adapter_gate": f"{adapter_gate:.4f}",
            "status": status,
            "resumed_from": self.resumed_from or "fresh run",
            "recovery_path": self.recovery_callback.latest_path
            or "pending first checkpoint upload",
            "note": note,
        }
        flyte.report.log(
            f"""
            <script>
            if (typeof window.updateQwenLiveReport === "function") {{
                window.updateQwenLiveReport({json.dumps(payload)});
            }}
            </script>
            """,
            do_flush=True,
        )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/callbacks.py*

`on_train_start` (see the full code) initializes the dashboard with an SVG loss chart and an HTML metrics table. Every `report_every_n_steps` training steps, `_push_update` serializes the latest metrics into a `<script>` block and calls `flyte.report.log()` to append it to the live page. The JavaScript `updateQwenLiveReport()` function then updates the chart polylines and appends a new table row for each step.
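
Stripped of Lightning, the reporting pattern has three steps: throttle by step, serialize with `json.dumps`, and wrap the payload in a `<script>` call. The sketch below is hypothetical in two ways:

* It assumes a report on every step divisible by `report_every_n_steps`; the callback's exact trigger isn't shown.
* It adds a safeguard the excerpt above doesn't show: escaping `</` so a string in the payload, such as a note, can't close the `<script>` tag early.

In the real callback, the resulting string goes to `flyte.report.log(..., do_flush=True)`.

```python
import json

def fmt(value):
    return f"{float(value):.4f}" if value is not None else "-"

def should_report(global_step, report_every_n_steps):
    # Hypothetical trigger: report on step 0 and every N steps after that.
    return report_every_n_steps > 0 and global_step % report_every_n_steps == 0

def build_update_script(payload):
    # json.dumps does not escape "/", so "</script>" inside a string value would end the
    # <script> element early; "<\/" is an equivalent escape inside a JSON/JS string.
    data = json.dumps(payload).replace("</", "<\\/")
    return (
        "<script>\n"
        'if (typeof window.updateQwenLiveReport === "function") {\n'
        f"  window.updateQwenLiveReport({data});\n"
        "}\n"
        "</script>"
    )

payload = {"step": 20, "train_total": fmt(0.5), "val_total": fmt(None), "note": "see </script>"}
if should_report(payload["step"], report_every_n_steps=10):
    html = build_update_script(payload)
    print(html)
```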

For resumed runs, the prior metrics history is seeded into the table on `on_train_start`, so the metrics view is continuous across runs rather than restarting from zero.

![Recovery](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/recovery.png)

WandB metrics are logged in parallel by `AdapterMetricsCallback` after each validation epoch, including per-epoch train and validation losses, the LM loss component, the reconstruction loss component, and the current adapter gate value.

![WandB](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/wandb.png)

### Evaluation

After training completes, a separate task runs inference on a single GPU:

```
@evaluation_env.task
async def evaluate_qwen_adapter(
    val_manifest: JsonlFile,
    images_dir: Dir,
    training_artifacts: Dir,
    config: Config,
) -> Dir:
    ...  # body omitted here; see tasks.py
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

The task is async, so it can use `asyncio.gather` to download the training artifacts and images concurrently rather than sequentially. This simple speedup matters when downloading hundreds of megabytes from object storage.

The evaluation task loads the adapter state dict from `adapter_artifact.pt` and rebuilds the frozen backbone fresh from Hugging Face. There's no need to checkpoint 3B weights: only the ~10,500 adapter weights travel with the artifact. It then runs greedy decoding on each validation example. The metric is exact-match accuracy between the model's predicted class token and the ground-truth CIFAR-10 label.
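
Exact-match accuracy is the fraction of examples whose decoded prediction equals the label. The sketch below assumes a light normalization before comparison: trimming whitespace, lowercasing, and dropping a trailing period. Whether the tutorial normalizes generated text this way isn't shown, and the predictions are made up.

```python
def normalize(text):
    """Hypothetical normalization: trim, lowercase, drop a trailing period."""
    return text.strip().lower().rstrip(".")

def exact_match_accuracy(predictions, labels):
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        return 0.0
    hits = sum(normalize(p) == normalize(t) for p, t in zip(predictions, labels))
    return hits / len(labels)

preds = ["cat", " Dog.", "truck", "bird"]
labels = ["cat", "dog", "automobile", "bird"]
print(exact_match_accuracy(preds, labels))  # 3 of 4 match -> 0.75
```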

### Putting it all together

The driver task is the pipeline entry point that all other tasks flow through:

```
@driver_env.task(report=True)
async def qwen_vl_multinode_deepspeed(
    model_name: str = DEFAULT_MODEL_NAME,
    max_train_samples: int = 1024,
    max_val_samples: int = 256,
    epochs: int = 8,
    per_device_batch_size: int = 1,
    target_global_batch_size: int = 16,
    learning_rate: float = 2e-4,
    reconstruction_loss_weight: float = 0.35,
    eval_examples: int = 16,
    resume_training_artifacts: Optional[Dir] = None,
    checkpoint_base_uri: Optional[str] = DEFAULT_CHECKPOINT_BASE_URI,
    wandb_project: str = "qwen-vl-multinode-deepspeed",
    wandb_entity: Optional[str] = None,
) -> Optional[Dir]:
    ...  # body omitted here; see tasks.py
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/tasks.py*

The driver constructs the recovery URI from `checkpoint_base_uri` and the current run name, prepares the dataset (or retrieves it from cache), then executes training inside a `wandb_config` context. The `wandb_config` context manager creates and registers a WandB run; the `@wandb_init` decorator on the training task retrieves it, updates it with the full `Config` dataclass, and attaches a `WandbLogger`. Neither the training task nor the callbacks need any WandB initialization code of their own.

The recovery handler (shown in the fault tolerance section above) wraps the training call. If training succeeds, evaluation runs next. The driver then downloads the training and evaluation artifacts concurrently and assembles a final HTML report, which it pushes to Flyte Reports. The report contains:

- training curves
- an evaluation summary
- a per-epoch metrics table
- sample prediction cards showing the actual occluded images

## Running the tutorial

Before running, update these placeholders:

- `DEFAULT_CHECKPOINT_BASE_URI` in `config.py`: your S3, GCS, or Azure Blob URI for checkpoint storage
- The `wandb_api_key` secret key name in `config.py`, to match your cluster's secret store configuration
- `wandb_entity` in `train.py` (marked with a `TODO`): your WandB entity

Then configure and launch:

```
if __name__ == "__main__":
    flyte.init_from_config()

    run = flyte.run(
        qwen_vl_multinode_deepspeed,
        model_name=DEFAULT_MODEL_NAME,
        max_train_samples=512,
        max_val_samples=128,
        epochs=5,
        per_device_batch_size=1,
        target_global_batch_size=16,
        learning_rate=2e-4,
        reconstruction_loss_weight=0.35,
        eval_examples=16,
        checkpoint_base_uri=DEFAULT_CHECKPOINT_BASE_URI,
        wandb_project="qwen-vl-multinode-deepspeed",
        wandb_entity="<YOUR_WANDB_ENTITY>",  # TODO: update with your own wandb entity
        # resume_training_artifacts=Dir(
        #     path="s3://flyte-examples/qwen-vl-multinode-deepspeed/<ACTION_NAME>/qwen_vl_training_recovery/"
        # ),
    )

    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/qwen_vl_frozen_backbone_finetuning/train.py*

```bash
flyte create config --endpoint <YOUR_ENDPOINT> --project <PROJECT> --domain <DOMAIN> --builder remote
uv run train.py
```

When you run this, the pipeline:

1. **Builds containers** once and caches them for subsequent runs
2. **Prepares the dataset**: downloads CIFAR-10, generates occlusions, writes JSONL manifests; cached on subsequent runs with the same config
3. **Launches multi-node training**: provisions 2 × 4 L40S GPUs and starts the Elastic job
4. **Streams metrics to the live dashboard**: the Flyte Reports view starts updating as soon as the first step logs
5. **Runs evaluation**: a single-GPU task loads the adapter and runs inference, computing exact-match accuracy
6. **Generates the final report**: training curves, evaluation summary, and sample prediction cards appear in the Flyte UI

![Final Report](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/final_report.png)
![Predictions](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/qwen-vl-finetuning/predictions.png)

To resume a failed or interrupted run, uncomment the `resume_training_artifacts` line in `train.py` and point it to the recovery URI from the previous run. Training picks up from the last checkpoint with metrics history intact.

## Going further

**Adapting this to a different task.** The frozen backbone pattern transfers directly. Replace `QwenOcclusionDataset` and `prepare_occlusion_dataset` with your own data, update the prompt template, and adjust the dual loss if a pixel-level reconstruction term doesn't apply to your task. The multi-node Elastic setup, DeepSpeed Stage 2 config, recovery system, and live reporting are completely reusable.

**Using a larger Qwen model.** Change `DEFAULT_MODEL_NAME` to `Qwen/Qwen2.5-VL-7B-Instruct` or a larger variant. You may need to increase `memory` in `training_env` and reduce `per_device_batch_size`. The frozen backbone + adapter pattern becomes more valuable at larger scales, because the trainable part stays the same ~10,500-parameter adapter regardless of backbone size.

**Training keeps failing.** Add `retries=3` to the `@training_env.task` decorator. With the recovery callback uploading checkpoints after every validation epoch, Flyte automatically restarts training from the last checkpoint on transient failures. Spot instance preemptions and most hardware hiccups become non-events.

**Scaling to more nodes.** Increase `NUM_NODES` in `config.py`. The Elastic plugin, DeepSpeed strategy, and gradient accumulation calculation all adapt automatically. The recovery system is unchanged, since each run still gets its own recovery URI.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/climate-modeling ===

# GPU-accelerated climate modeling

Climate modeling is hard for two reasons: data and compute. Satellite imagery arrives continuously from multiple sources. Reanalysis datasets have to be pulled from remote APIs. Weather station data shows up in different formats and schemas. And once all of that is finally in one place, running atmospheric physics simulations demands serious GPU compute.

In practice, many climate workflows are held together with scripts, cron jobs, and a lot of manual babysitting. Data ingestion breaks without warning. GPU jobs run overnight with little visibility into what's happening. When something interesting shows up in a simulation, like a developing hurricane, no one notices until the job finishes hours later.

In this tutorial, we build a production-grade climate modeling pipeline using Flyte. We ingest data from three different sources in parallel, combine it with Dask, run ensemble atmospheric simulations on H200 GPUs, detect extreme weather events as they emerge, and visualize everything in a live dashboard. The entire pipeline is orchestrated, cached, and fault-tolerant, so it can run reliably at scale.

![Report](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/tutorials/climate-modeling/report.png)

> [!NOTE]
> Full code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/climate_modeling/simulation.py).

## Overview

We're building an ensemble weather forecasting system. Ensemble forecasting runs the same simulation many times with slightly different initial conditions, and the spread across those runs quantifies forecast uncertainty. Instead of saying "the temperature will be 25°C", we can say "the temperature will be 24-26°C with 90% confidence".
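
Here is a toy version of that idea, with made-up numbers:

1. Perturb an initial temperature with Gaussian noise (standard deviation 0.5 °C).
2. Push each member through a hypothetical stand-in model.
3. Read a 90% interval off the 5th and 95th percentiles of the members.

Real ensembles perturb full 3D fields and run physics, not a one-line function.

```python
import random
import statistics

def toy_forecast(initial_temp_c):
    """Hypothetical stand-in for a simulation: a slightly nonlinear response to the initial state."""
    return initial_temp_c + 0.8 + 0.05 * (initial_temp_c - 24.0) ** 2

rng = random.Random(0)  # fixed seed so the ensemble is reproducible
members = [toy_forecast(24.0 + rng.gauss(0.0, 0.5)) for _ in range(800)]

mean = statistics.fmean(members)
spread = statistics.stdev(members)
cuts = statistics.quantiles(members, n=20)  # 19 cut points: 5th, 10th, ..., 95th percentile
low, high = cuts[0], cuts[-1]
print(f"mean {mean:.2f} °C, spread {spread:.2f} °C, 90% interval {low:.2f}-{high:.2f} °C")
```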

The pipeline has five stages:

1. **Data ingestion**: Pull satellite imagery from NOAA GOES, reanalysis data from ERA5, and surface observations from weather stations in parallel.
2. **Preprocessing**: Fuse the datasets, interpolate to a common grid, and run quality control using Dask for distributed computation.
3. **GPU simulation**: Run ensemble atmospheric physics on H200 GPUs. Each ensemble member evolves independently. PyTorch handles the tensor operations; `torch.compile` optimizes the kernels.
4. **Event detection**: Monitor for hurricanes (high wind + low pressure) and heatwaves during simulation. When extreme events are detected, the pipeline can adaptively refine the grid resolution.
5. **Real-time reporting**: Stream metrics to a live Flyte Reports dashboard showing convergence and detected events.
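
A hypothetical threshold check for the event detection stage might look like the following. The hurricane wind threshold is 33 m/s, roughly the 64-knot hurricane-force boundary. The pressure cutoff and the heatwave rule (a temperature threshold held for several consecutive days) are illustrative assumptions, not the tutorial's values.

```python
HURRICANE_WIND_MPS = 33.0      # about 64 kt, hurricane-force winds
HURRICANE_PRESSURE_MB = 980.0  # assumption: illustrative central-pressure cutoff
HEATWAVE_TEMP_C = 35.0         # assumption: daily maximum temperature threshold
HEATWAVE_MIN_DAYS = 3          # assumption: consecutive days required

def detect_hurricane(max_wind_mps, min_pressure_mb):
    # High wind and low pressure must occur together.
    return max_wind_mps >= HURRICANE_WIND_MPS and min_pressure_mb <= HURRICANE_PRESSURE_MB

def detect_heatwave(daily_max_temps_c):
    run = 0
    for temp in daily_max_temps_c:
        run = run + 1 if temp >= HEATWAVE_TEMP_C else 0   # reset on any cooler day
        if run >= HEATWAVE_MIN_DAYS:
            return True
    return False

def detected_phenomena(max_wind_mps, min_pressure_mb, daily_max_temps_c):
    events = []
    if detect_hurricane(max_wind_mps, min_pressure_mb):
        events.append("hurricane")
    if detect_heatwave(daily_max_temps_c):
        events.append("heatwave")
    return events

print(detected_phenomena(41.0, 962.0, [33.0, 36.1, 36.8, 37.2]))  # ['hurricane', 'heatwave']
```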

This workflow is a good example of where Flyte shines!

- **Parallel data ingestion**: Three different data sources, three different APIs, all running concurrently. Flyte's async task execution handles this naturally.
- **Resource heterogeneity**: Data ingestion needs CPU and network. Preprocessing needs a Dask cluster. Simulation needs GPUs. Flyte provisions exactly what each stage needs.
- **Caching**: ERA5 data fetches can take minutes. Run the pipeline twice with the same date range, and Flyte skips the fetch entirely.
- **Adaptive workflows**: When a hurricane is detected, we can dynamically refine the simulation. Flyte makes this kind of conditional logic straightforward.

## Implementation

### Dependencies and container image

```
import asyncio
import gc
import io
import json
import os
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Literal

import flyte
import numpy as np
import pandas as pd
import xarray as xr
from flyte.io import File
from flyteplugins.dask import Dask, Scheduler, WorkerGroup
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

The key imports include `xarray` for multi-dimensional climate data, `flyteplugins.dask` for distributed preprocessing, and `flyte` for orchestration.

```
climate_image = (
    flyte.Image.from_debian_base(name="climate_modeling_h200")
    .with_apt_packages(
        "libnetcdf-dev",  # NetCDF for climate data
        "libhdf5-dev",  # HDF5 for large datasets
        "libeccodes-dev",  # GRIB format support (ECMWF's native format)
        "libudunits2-dev",  # Unit conversions
    )
    .with_pip_packages(
        "numpy==2.3.5",
        "pandas==2.3.3",
        "xarray==2025.11.0",
        "torch==2.9.1",
        "netCDF4==1.7.3",
        "s3fs==2025.10.0",
        "aiohttp==3.13.2",
        "ecmwf-datastores-client==0.4.1",
        "h5netcdf==1.7.3",
        "cfgrib==0.9.15.1",
        "pyarrow==22.0.0",
        "scipy==1.15.1",
        "flyteplugins-dask>=2.0.0b33",
        "nvidia-ml-py3==7.352.0",
    )
    .with_env_vars({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512"})
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

Climate data comes in specialized formats such as NetCDF, HDF5, and GRIB. The container image includes libraries to work with all of them, along with PyTorch for GPU computation and the ECMWF client for accessing ERA5 data.

### Simulation parameters and data structures

```
@dataclass
class SimulationParams:
    grid_resolution_km: float = 10.0
    time_step_minutes: int = 10
    simulation_hours: int = 240
    physics_model: Literal["WRF", "MPAS", "CAM"] = "WRF"
    boundary_layer_scheme: str = "YSU"
    microphysics_scheme: str = "Thompson"
    radiation_scheme: str = "RRTMG"

    # Ensemble forecasting parameters
    ensemble_size: int = 800
    perturbation_magnitude: float = 0.5

    # Convergence criteria for adaptive refinement
    convergence_threshold: float = 0.1  # 10% of initial ensemble spread
    max_iterations: int = 3

@dataclass
class ClimateMetrics:
    timestamp: str
    iteration: int
    convergence_rate: float
    energy_conservation_error: float
    max_wind_speed_mps: float
    min_pressure_mb: float
    detected_phenomena: list[str]
    compute_time_seconds: float
    ensemble_spread: float

@dataclass
class SimulationSummary:
    total_iterations: int
    final_resolution_km: float
    avg_convergence_rate: float
    total_compute_time_seconds: float
    hurricanes_detected: int
    heatwaves_detected: int
    converged: bool
    region: str
    output_files: list[File]
    date_range: list[str]  # [start_date, end_date]
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

`SimulationParams` defines the core behavior of the simulation, including grid resolution, physics schemes, and ensemble size. The default configuration runs 800 ensemble members, which is sufficient to produce statistically meaningful uncertainty estimates.

> [!NOTE]
> Decreasing the grid spacing via `grid_resolution_km` (for example, from 10 km to 5 km) increases grid resolution and significantly increases memory usage. Halving the spacing roughly quadruples the number of horizontal grid points, and with them the intermediate state computed for each point. Even with 141 GB of H200 GPU memory, high-resolution or adaptively refined simulations may exceed available VRAM, especially when running large ensembles.
>
> To mitigate this, consider reducing the ensemble size, limiting the refined region, running fewer physics variables, or scaling the simulation across more GPUs so memory is distributed more evenly.
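
A rough estimate shows why refinement hits memory so quickly. The sketch assumes the following, all hypothetical:

* a 2,000 km × 2,000 km domain
* 10 vertical levels
* 5 prognostic variables
* float32 storage for a single copy of the ensemble state

Real simulations also hold intermediate tensors, so actual usage is higher.

```python
def ensemble_state_bytes(domain_km, resolution_km, levels, variables, members, bytes_per_value=4):
    """Bytes for one copy of the full ensemble state on a square horizontal grid."""
    points_per_side = int(domain_km / resolution_km)
    grid_points = points_per_side ** 2 * levels
    return grid_points * variables * members * bytes_per_value

for res in (10.0, 5.0):
    gb = ensemble_state_bytes(2_000, res, levels=10, variables=5, members=800) / 1e9
    print(f"{res:>4} km grid: ~{gb:.1f} GB for one copy of the ensemble state")
```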

`ClimateMetrics` collects diagnostics at each iteration, such as convergence rate, energy conservation, and detected phenomena. These metrics are streamed to the real-time dashboard so you can monitor how the simulation evolves as it runs.

### Task environments

Different stages need different resources. Flyte's `TaskEnvironment` declares exactly what each task requires:

```
gpu_env = flyte.TaskEnvironment(
    name="climate_modeling_gpu",
    resources=flyte.Resources(
        cpu=5,
        memory="130Gi",
        gpu="H200:1",
    ),
    image=climate_image,
    cache="auto",
)

dask_env = flyte.TaskEnvironment(
    name="climate_modeling_dask",
    plugin_config=Dask(
        scheduler=Scheduler(resources=flyte.Resources(cpu=2, memory="6Gi")),
        workers=WorkerGroup(
            number_of_workers=2,
            resources=flyte.Resources(cpu=2, memory="12Gi"),
        ),
    ),
    image=climate_image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),  # Head node
    cache="auto",
)

cpu_env = flyte.TaskEnvironment(
    name="climate_modeling_cpu",
    resources=flyte.Resources(cpu=8, memory="64Gi"),
    image=climate_image,
    cache="auto",
    secrets=[
        flyte.Secret(key="cds_api_key", as_env_var="ECMWF_DATASTORES_KEY"),
        flyte.Secret(key="cds_api_url", as_env_var="ECMWF_DATASTORES_URL"),
    ],
    depends_on=[gpu_env, dask_env],
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

Here’s what each environment is responsible for:

- **`gpu_env`**: Runs the atmospheric simulations on a single H200 GPU, whose 141 GB of VRAM holds the ensemble members during execution. The `memory="130Gi"` request is host (CPU) memory for the task container, not GPU memory.
- **`dask_env`**: Provides a distributed Dask cluster for preprocessing. A scheduler and two workers handle data fusion and transformation in parallel.
- **`cpu_env`**: Handles data ingestion and orchestration. This environment also includes the secrets required to access the ERA5 API.

The `depends_on` setting on `cpu_env` tells Flyte to deploy the GPU and Dask environments along with it. When the orchestration task in `cpu_env` runs, the simulation and preprocessing tasks it launches are already available.

### Data ingestion: multiple sources in parallel

Climate models need data from multiple sources. Each source has different formats, APIs, and failure modes. We handle them as separate Flyte tasks that run concurrently.

**Satellite imagery from NOAA GOES**

```
@cpu_env.task
async def ingest_satellite_data(region: str, date_range: list[str]) -> File:
    """Ingest GOES satellite imagery from NOAA's public S3 buckets."""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

This task fetches cloud imagery and precipitable water products from NOAA's public S3 buckets. GOES-16 covers the Atlantic; GOES-17 covers the Pacific. The task selects the appropriate satellite based on region, fetches multiple days in parallel using `asyncio.gather`, and combines everything into a single xarray Dataset.
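
The fan-out pattern in this task can be sketched with `asyncio.gather`. Everything here is hypothetical:

* `fetch_goes_day` stands in for the real S3 download.
* The region names and the region-to-satellite mapping are illustrative.
* `date_range` is assumed to hold inclusive ISO `YYYY-MM-DD` start and end dates.

```python
import asyncio
from datetime import date, timedelta

PACIFIC_REGIONS = {"pacific", "west_coast"}  # hypothetical region names

def select_satellite(region):
    # GOES-17 for Pacific regions, GOES-16 otherwise (Atlantic coverage).
    return "goes17" if region in PACIFIC_REGIONS else "goes16"

def expand_days(date_range):
    start, end = (date.fromisoformat(d) for d in date_range)
    return [start + timedelta(days=i) for i in range((end - start).days + 1)]

async def fetch_goes_day(satellite, day):
    await asyncio.sleep(0.01)  # placeholder for network I/O
    return f"{satellite}:{day.isoformat()}"

async def fetch_all(region, date_range):
    satellite = select_satellite(region)
    days = expand_days(date_range)
    # All days are requested concurrently; gather returns results in input order.
    return await asyncio.gather(*(fetch_goes_day(satellite, d) for d in days))

results = asyncio.run(fetch_all("atlantic", ["2024-09-24", "2024-09-26"]))
print(results)
```

In the real task, the per-day results are then combined into a single xarray Dataset.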

**ERA5 reanalysis from Copernicus**

```
@cpu_env.task
async def ingest_reanalysis_data(region: str, date_range: list[str, str]) -> File:
    """Fetch ERA5 reanalysis from Copernicus Climate Data Store."""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

ERA5 provides 3D atmospheric fields, such as temperature, wind, and humidity, at multiple pressure levels from the surface to the stratosphere. The ECMWF datastores client reads its credentials from the Flyte secrets exposed as environment variables. Each day is fetched in parallel, and the results are concatenated.

**Surface observations from weather stations**

```
@cpu_env.task
async def ingest_station_data(
    region: str, date_range: list[str, str], max_stations: int = 100
) -> File:
    """Fetch ground observations from NOAA's Integrated Surface Database."""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

Ground truth comes from NOAA's Integrated Surface Database. The task filters stations by geographic bounds, fetches hourly observations, and returns a Parquet file for efficient downstream processing.
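Filtering stations by geographic bounds amounts to a bounding-box test. In this sketch, `REGION_BOUNDS`, `Station`, and `stations_in_region` are hypothetical names, and the Atlantic box is an illustrative value rather than the tutorial's own region definition:

```python
from dataclasses import dataclass

# Hypothetical bounding boxes: (lat_min, lat_max, lon_min, lon_max) in degrees.
# The tutorial's actual region definitions may differ.
REGION_BOUNDS = {
    "atlantic": (5.0, 45.0, -100.0, -10.0),
}


@dataclass
class Station:
    station_id: str
    lat: float
    lon: float


def stations_in_region(
    stations: list[Station], region: str, max_stations: int = 100
) -> list[Station]:
    """Keep stations inside the region's bounding box, capped at max_stations."""
    lat_min, lat_max, lon_min, lon_max = REGION_BOUNDS[region]
    inside = [
        s
        for s in stations
        if lat_min <= s.lat <= lat_max and lon_min <= s.lon <= lon_max
    ]
    return inside[:max_stations]
```

A box that crosses the 180° meridian, as a Pacific region might, would need a wrap-around longitude test, which this sketch omits.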

All three tasks return Flyte `File` objects that hold references to data in blob storage. No data moves until a downstream task actually needs it.

### Preprocessing with Dask

The three data sources need to be combined into a unified atmospheric state. This means:
- Interpolating to a common grid
- Handling missing values
- Merging variables from different sources
- Quality control

This is a natural fit for Dask, which evaluates lazily over chunked arrays:

```python
@dask_env.task
async def preprocess_atmospheric_data(
    satellite_data: File,
    reanalysis_data: File,
    station_data: File,
    target_resolution_km: float,
) -> File:
    ...  # Body omitted for brevity
```

This task connects to the Dask cluster provisioned by Flyte, loads the datasets with appropriate chunking, merges satellite and reanalysis grids, fills in missing values, and persists the result. Flyte caches the output, so preprocessing only runs when the inputs change.
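Two of the preprocessing steps, filling missing values and interpolating to a common grid, can be sketched in memory with NumPy. This is an illustration under simplifying assumptions (rectilinear lat/lon grids with increasing coordinates, and a plain mean fill for gaps), not the tutorial's Dask code; the function names are hypothetical:

```python
import numpy as np


def fill_missing(field: np.ndarray) -> np.ndarray:
    """Replace NaNs with the mean of the valid values (a deliberately simple fill)."""
    filled = field.copy()
    filled[np.isnan(filled)] = np.nanmean(field)
    return filled


def regrid_bilinear(
    field: np.ndarray,
    src_lat: np.ndarray,
    src_lon: np.ndarray,
    dst_lat: np.ndarray,
    dst_lon: np.ndarray,
) -> np.ndarray:
    """Bilinear regridding between rectilinear grids, longitude pass then latitude pass.

    np.interp needs increasing source coordinates and clamps points outside their range.
    """
    # Pass 1: interpolate each source latitude row onto the target longitudes
    tmp = np.stack([np.interp(dst_lon, src_lon, row) for row in field])
    # Pass 2: interpolate each resulting column onto the target latitudes
    return np.stack([np.interp(dst_lat, src_lat, col) for col in tmp.T], axis=1)
```

Filling before regridding matters because `np.interp` spreads a NaN to every interpolated value that uses it.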

### GPU-accelerated atmospheric simulation

Now the core: running atmospheric physics on the GPU. Each ensemble member is an independent forecast with slightly perturbed initial conditions.

```
@gpu_env.task
async def run_atmospheric_simulation(
    input_data: File,
    params: SimulationParams,
    partition_id: int = 0,
    ensemble_start: int | None = None,
    ensemble_end: int | None = None,
) -> tuple[File, ClimateMetrics]:
    """Run GPU-accelerated atmospheric simulation with ensemble forecasting."""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

The task accepts a subset of ensemble members (`ensemble_start` to `ensemble_end`). This enables distributing 800 members across multiple GPUs.

The physics step is the computational kernel. It runs advection (wind transport), pressure gradients, Coriolis forces, turbulent diffusion, and moisture condensation:

```
    @torch.compile(mode="reduce-overhead")
    def physics_step(state_tensor, dt_val, dx_val):
        """Compiled atmospheric physics - 3-4x faster with torch.compile."""
        # Advection: transport by wind
        temp_grad_x = torch.roll(state_tensor[:, 0], -1, dims=2) - torch.roll(
            state_tensor[:, 0], 1, dims=2
        )
        temp_grad_y = torch.roll(state_tensor[:, 0], -1, dims=3) - torch.roll(
            state_tensor[:, 0], 1, dims=3
        )
        advection = -(
            state_tensor[:, 3] * temp_grad_x + state_tensor[:, 4] * temp_grad_y
        ) / (2 * dx_val)
        state_tensor[:, 0] = state_tensor[:, 0] + advection * dt_val

        # Pressure gradient with Coriolis
        pressure_grad_x = (
            torch.roll(state_tensor[:, 1], -1, dims=2)
            - torch.roll(state_tensor[:, 1], 1, dims=2)
        ) / (2 * dx_val)
        pressure_grad_y = (
            torch.roll(state_tensor[:, 1], -1, dims=3)
            - torch.roll(state_tensor[:, 1], 1, dims=3)
        ) / (2 * dx_val)

        coriolis_param = 1e-4  # ~45°N latitude
        coriolis_u = coriolis_param * state_tensor[:, 4]
        coriolis_v = -coriolis_param * state_tensor[:, 3]

        state_tensor[:, 3] = (
            state_tensor[:, 3] - pressure_grad_x * dt_val * 0.01 + coriolis_u * dt_val
        )
        state_tensor[:, 4] = (
            state_tensor[:, 4] - pressure_grad_y * dt_val * 0.01 + coriolis_v * dt_val
        )

        # Turbulent diffusion
        diffusion_coeff = 10.0
        laplacian_temp = (
            torch.roll(state_tensor[:, 0], 1, dims=2)
            + torch.roll(state_tensor[:, 0], -1, dims=2)
            + torch.roll(state_tensor[:, 0], 1, dims=3)
            + torch.roll(state_tensor[:, 0], -1, dims=3)
            - 4 * state_tensor[:, 0]
        ) / (dx_val * dx_val)
        state_tensor[:, 0] = (
            state_tensor[:, 0] + diffusion_coeff * laplacian_temp * dt_val
        )

        # Moisture condensation
        sat_vapor_pressure = 611.2 * torch.exp(
            17.67 * state_tensor[:, 0] / (state_tensor[:, 0] + 243.5)
        )
        condensation = torch.clamp(
            state_tensor[:, 2] - sat_vapor_pressure * 0.001, min=0
        )
        state_tensor[:, 2] = state_tensor[:, 2] - condensation * 0.1
        state_tensor[:, 0] = state_tensor[:, 0] + condensation * 2.5e6 / 1005 * dt_val

        return state_tensor
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

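`@torch.compile(mode="reduce-overhead")` compiles this function with TorchInductor, which fuses the elementwise operations into optimized GPU kernels; the `reduce-overhead` mode additionally uses CUDA graphs to cut kernel-launch overhead. Combined with mixed precision (`torch.cuda.amp.autocast`), this runs 3-4x faster than eager PyTorch.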
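The paired `torch.roll` calls in `physics_step` are centered finite differences on a periodic (wrap-around) grid: rolling by -1 moves each cell's next neighbor into place and rolling by +1 moves its previous neighbor, so their difference divided by `2 * dx` approximates the first derivative. The same stencils in NumPy, with hypothetical function names, can be checked against an analytic derivative:

```python
import numpy as np


def centered_diff(field: np.ndarray, axis: int, dx: float) -> np.ndarray:
    """Centered first derivative with periodic boundaries, as in physics_step."""
    # np.roll(field, -1)[i] == field[i + 1]; np.roll(field, 1)[i] == field[i - 1]
    return (np.roll(field, -1, axis=axis) - np.roll(field, 1, axis=axis)) / (2 * dx)


def laplacian(field: np.ndarray, dx: float) -> np.ndarray:
    """Five-point Laplacian on a 2D periodic grid with equal spacing in both axes."""
    return (
        np.roll(field, 1, axis=0)
        + np.roll(field, -1, axis=0)
        + np.roll(field, 1, axis=1)
        + np.roll(field, -1, axis=1)
        - 4 * field
    ) / (dx * dx)


n = 64
dx = 2 * np.pi / n
x = np.arange(n) * dx
f = np.tile(np.sin(x), (8, 1))  # sin(x) along axis 1, constant along axis 0
```

For `sin(x)` on this grid, `centered_diff` should closely match `cos(x)` and `laplacian` should closely match `-sin(x)`; both stencils are second-order accurate, so the error shrinks as `dx**2`.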

Every 10 timesteps, the simulation checks for extreme events:
- **Hurricanes**: Wind speed > 33 m/s with low pressure
- **Heatwaves**: Temperature anomalies exceeding thresholds

Detected phenomena get logged to the metrics, which flow to the live dashboard.

### Distributing across multiple GPUs

800 ensemble members is a lot for one GPU, so we distribute them:

```
@cpu_env.task
async def run_distributed_simulation_ensemble(
    preprocessed_data: File, params: SimulationParams, n_gpus: int
) -> tuple[list[File], list[ClimateMetrics]]:
    total_members = params.ensemble_size
    members_per_gpu = total_members // n_gpus

    # Distribute ensemble members across GPUs
    tasks = []
    for gpu_id in range(n_gpus):
        # Calculate ensemble range for this GPU
        ensemble_start = gpu_id * members_per_gpu
        # Last GPU gets any remainder members
        if gpu_id == n_gpus - 1:
            ensemble_end = total_members
        else:
            ensemble_end = ensemble_start + members_per_gpu

        # Launch GPU task with ensemble subset
        gpu_task = run_atmospheric_simulation(
            preprocessed_data,
            params,
            gpu_id,
            ensemble_start=ensemble_start,
            ensemble_end=ensemble_end,
        )
        tasks.append(gpu_task)

    # Run all GPU tasks in parallel
    results = await asyncio.gather(*tasks)

    output_files = [r[0] for r in results]
    metrics = [r[1] for r in results]

    return output_files, metrics
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

The task divides the ensemble members across the available GPUs (the last GPU also takes any remainder), launches the simulation runs in parallel using `asyncio.gather`, and then aggregates the results. With 800 members and five GPUs, each GPU runs 160 ensemble members. Flyte handles scheduling, so GPU tasks start automatically as soon as resources become available.
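As written, `run_distributed_simulation_ensemble` gives the last GPU the whole remainder: with 803 members and five GPUs, the last GPU would run 163 members while the others run 160. If that imbalance matters, one alternative (not what the tutorial does) spreads the remainder one member at a time. `split_evenly` is a hypothetical helper:

```python
def split_evenly(total_members: int, n_gpus: int) -> list[tuple[int, int]]:
    """Return (start, end) ranges whose sizes differ by at most one member."""
    base, remainder = divmod(total_members, n_gpus)
    ranges = []
    start = 0
    for gpu_id in range(n_gpus):
        # The first `remainder` GPUs each take one extra member
        size = base + (1 if gpu_id < remainder else 0)
        ranges.append((start, start + size))
        start += size
    return ranges
```

Each `(start, end)` pair maps directly onto the `ensemble_start` and `ensemble_end` arguments of `run_atmospheric_simulation`.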

### The main workflow

Everything comes together in the orchestration task:

```
@cpu_env.task(report=True)
async def adaptive_climate_modeling_workflow(
    region: str = "atlantic",
    date_range: list[str, str] = ["2024-09-01", "2024-09-10"],
    current_params: SimulationParams = SimulationParams(),
    enable_multi_gpu: bool = True,
    n_gpus: int = 5,
) -> SimulationSummary:
    """Orchestrates multi-source ingestion, GPU simulation, and adaptive refinement."""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

`report=True` enables Flyte Reports for live monitoring.

```
    # Parallel data ingestion from three sources
    with flyte.group("data-ingestion"):
        satellite_task = ingest_satellite_data(region, date_range)
        reanalysis_task = ingest_reanalysis_data(region, date_range)
        station_task = ingest_station_data(region, date_range)

        satellite_data, reanalysis_data, station_data = await asyncio.gather(
            satellite_task,
            reanalysis_task,
            station_task,
        )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

`flyte.group("data-ingestion")` visually groups the ingestion tasks in the Flyte UI. Inside the group, three tasks launch concurrently. `asyncio.gather` waits for all three to complete before preprocessing begins.

The workflow then enters an iterative loop:
1. Run GPU simulation (single or multi-GPU)
2. Check convergence by comparing forecasts across iterations
3. Detect extreme events
4. If a hurricane is detected and we haven't refined yet, double the grid resolution
5. Stream metrics to the live dashboard
6. Repeat until converged or max iterations reached

Adaptive mesh refinement is the key feature here. When the simulation detects a hurricane forming, it automatically increases resolution to capture the fine-scale dynamics. This is expensive, so we limit it to one refinement per run.
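The control flow of that loop can be sketched in plain Python. This is a simplified sketch under stated assumptions: `simulate` is a hypothetical callable standing in for the GPU tasks, convergence is measured as the relative change of a single scalar forecast, and restarting the comparison after refinement is a choice made here, not necessarily the tutorial's. Timestep adjustment and dashboard streaming are omitted:

```python
from typing import Callable

# Hypothetical signature: simulate(resolution_km) -> (forecast_value, hurricane_detected)
Simulator = Callable[[float], tuple[float, bool]]


def adaptive_loop(
    simulate: Simulator,
    resolution_km: float = 10.0,
    tolerance: float = 0.01,
    max_iterations: int = 5,
) -> list[tuple[int, float, float]]:
    """Iterate until successive forecasts agree, refining the grid at most once."""
    refined = False
    previous = None
    history: list[tuple[int, float, float]] = []
    for iteration in range(max_iterations):
        forecast, hurricane = simulate(resolution_km)
        history.append((iteration, resolution_km, forecast))

        # Convergence: relative change versus the previous iteration
        converged = previous is not None and abs(forecast - previous) <= tolerance * abs(previous)

        if hurricane and not refined:
            # Halving the grid spacing doubles the resolution; allowed only once
            resolution_km /= 2
            refined = True
            previous = None  # forecasts at the new resolution restart the comparison
            continue

        if converged:
            break
        previous = forecast
    return history
```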

### Running the pipeline

```
if __name__ == "__main__":
    flyte.init_from_config()
    run_multi_gpu = flyte.run(adaptive_climate_modeling_workflow)

    print(f"Run URL: {run_multi_gpu.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/climate_modeling/simulation.py*

Before running, create a Flyte config that points at your endpoint, so the CLI commands that follow know where to connect:

```bash
flyte create config --endpoint <FLYTE_OR_UNION_ENDPOINT> --project <PROJECT_NAME> --domain <DOMAIN_NAME> --builder remote
```

Next, set up ERA5 API credentials as secrets:

```bash
flyte create secret cds_api_key <YOUR_CDS_API_KEY>
flyte create secret cds_api_url https://cds.climate.copernicus.eu/api
```

Then launch:

```bash
uv run simulation.py
```

The default configuration uses the Atlantic region from September 1 to 10, 2024, during peak hurricane season.

## Key concepts

### Ensemble forecasting

Weather prediction is inherently uncertain. Small errors in the initial conditions grow over time due to chaotic dynamics, which means a single forecast can only ever be one possible outcome.

Ensemble forecasting addresses this uncertainty by:
- Perturbing the initial conditions within known observational error bounds
- Running many independent forecasts
- Computing the ensemble mean as the most likely outcome and the ensemble spread as a measure of uncertainty
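A toy NumPy sketch of those three steps follows. The `step` function is a hypothetical stand-in for one model timestep; in the usage below, a simple amplifying map mimics error growth and is not real atmospheric dynamics:

```python
from typing import Callable

import numpy as np


def run_ensemble(
    initial_state: np.ndarray,
    n_members: int,
    obs_error: float,
    step: Callable[[np.ndarray], np.ndarray],
    n_steps: int,
    seed: int = 0,
) -> tuple[np.ndarray, np.ndarray]:
    """Perturb the initial state per member, integrate each, return mean and spread."""
    rng = np.random.default_rng(seed)
    # Each member starts from the analysis plus Gaussian noise at the observational error scale
    members = initial_state + rng.normal(0.0, obs_error, size=(n_members, *initial_state.shape))
    for _ in range(n_steps):
        members = step(members)
    mean = members.mean(axis=0)  # most likely outcome
    spread = members.std(axis=0, ddof=1)  # uncertainty estimate
    return mean, spread


def amplify(members: np.ndarray) -> np.ndarray:
    """Toy dynamics: every perturbation grows by 10% per step."""
    return 1.1 * members
```

With the same seed, running `amplify` for ten steps multiplies the spread by `1.1**10` (about 2.6), which is the "small errors grow over time" behavior in miniature.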

### Adaptive mesh refinement

When a hurricane begins to form, coarse spatial grids are not sufficient to resolve critical features like eyewall dynamics. Adaptive mesh refinement allows the simulation to focus compute where it matters most by:
- Increasing grid resolution, for example from 10 km to 5 km
- Reducing the timestep to maintain numerical stability
- Refining only the regions of interest instead of the entire domain

This approach is computationally expensive, but it is essential for producing accurate intensity forecasts.
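The link between grid spacing and timestep is the Courant-Friedrichs-Lewy (CFL) condition for explicit advection schemes: information must not travel more than a fraction C of a grid cell per step, so dt ≤ C·Δx/u_max. A minimal sketch, assuming a Courant number of 0.5 and a 50 m/s peak wind (illustrative values, not the tutorial's settings):

```python
def max_stable_timestep(dx_m: float, max_wind_ms: float, courant: float = 0.5) -> float:
    """Largest advective timestep (seconds) satisfying u * dt / dx <= courant."""
    return courant * dx_m / max_wind_ms


coarse_dt = max_stable_timestep(10_000, 50)  # 10 km grid
fine_dt = max_stable_timestep(5_000, 50)  # 5 km grid after refinement
```

Under these assumptions, going from 10 km to 5 km halves the stable timestep from 100 s to 50 s. Refining a 2D horizontal domain everywhere would quadruple the cell count and double the number of steps, roughly 8x the work, which is why refinement is limited to the regions of interest.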

### Real-time event detection

Rather than analyzing results after a simulation completes, this pipeline detects significant events as the simulation runs.

The system monitors for conditions such as:
- **Hurricanes**: Wind speeds exceeding 33 m/s (Category 1 threshold) combined with central pressure below 980 mb
- **Heatwaves**: Sustained temperature anomalies over a defined period

Detecting these events in real time enables adaptive responses, such as refining the simulation or triggering alerts, and supports earlier warnings for extreme weather.
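Both checks can be sketched as simple array tests. The hurricane thresholds come from the list above (1 mb equals 1 hPa); the heatwave threshold and duration are caller-supplied because the text does not fix them. These are illustrative functions, not the pipeline's detector:

```python
import numpy as np

HURRICANE_WIND_MS = 33.0  # Category 1 threshold
HURRICANE_PRESSURE_HPA = 980.0  # central pressure threshold (980 mb)


def detect_hurricane(wind_u: np.ndarray, wind_v: np.ndarray, pressure_hpa: np.ndarray) -> np.ndarray:
    """Boolean mask of grid cells meeting both hurricane criteria."""
    speed = np.hypot(wind_u, wind_v)
    return (speed > HURRICANE_WIND_MS) & (pressure_hpa < HURRICANE_PRESSURE_HPA)


def detect_heatwave(temp_anomaly, threshold: float, min_duration: int) -> bool:
    """True if the anomaly exceeds `threshold` for at least `min_duration` consecutive steps."""
    run = 0
    for value in temp_anomaly:
        run = run + 1 if value > threshold else 0
        if run >= min_duration:
            return True
    return False
```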

## Where to go next

This example is intentionally scoped to keep the ideas clear, but there are several natural ways to extend it for more realistic workloads.

To model different ocean basins, change the `region` parameter to values like `"pacific"` or `"indian"`. The ingestion tasks automatically adjust to pull the appropriate satellite coverage for each region.

To run longer forecasts, increase `simulation_hours` in `SimulationParams`. The default of 240 hours, or 10 days, is typical for medium-range forecasting, but you can run longer simulations if you have the compute budget.

Finally, the physics step here is deliberately simplified. Production systems usually incorporate additional components such as radiation schemes, boundary layer parameterizations, and land surface models. These can be added incrementally as separate steps without changing the overall structure of the pipeline.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/trading-agents ===

# Multi-agent trading simulation

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/trading_agents); based on work by [TauricResearch](https://github.com/TauricResearch/TradingAgents).

This example walks you through building a multi-agent trading simulation, modeling how agents within a firm might interact, strategize, and make trades collaboratively.

![Trading agents execution visualization](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/trading-agents/execution.png)
_Trading agents execution visualization_

## TL;DR

- You'll build a trading firm made up of agents that analyze, argue, and act, modeled with Python functions.
- You'll use the Flyte SDK to orchestrate the simulation, giving you visibility, retries, caching, and durability.
- You'll learn how to plug in tools, structure conversations, and track decisions across agents.
- You'll see how agents debate, use context, generate reports, and retain memory via vector DBs.

## What is an agent, anyway?

Agentic workflows are a rising pattern for complex problem-solving with LLMs. Think of an agent as the combination of:

- An LLM (like GPT-4 or Mistral)
- A loop that keeps it reasoning until a goal is met
- A set of optional tools they can call (APIs, search, calculators, etc.)
- Enough tokens to reason about the problem at hand

That's it.

You define tools, bind them to an agent, and let it run: it reasons step by step, optionally calling those tools, until it finishes.
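That loop fits in a few lines of framework-free Python. In this sketch, `llm` is any callable that returns either a tool request or a final answer (a hypothetical protocol; real chat-model APIs, such as OpenAI tool calling, use their own message formats), and `scripted_llm` fakes a model so the example runs offline:

```python
from typing import Callable

# A "tool" is just a named Python function the agent may call.
Tool = Callable[[str], str]


def run_agent(
    llm: Callable[[list[str]], dict],
    tools: dict[str, Tool],
    goal: str,
    max_steps: int = 5,
) -> str:
    """Minimal agent loop: ask the model, run any requested tool, repeat until it answers."""
    transcript = [f"goal: {goal}"]
    for _ in range(max_steps):
        decision = llm(transcript)
        if "answer" in decision:
            return decision["answer"]
        # The model asked for a tool: run it and feed the result back
        name, arg = decision["tool"], decision["input"]
        transcript.append(f"{name}({arg}) -> {tools[name](arg)}")
    raise RuntimeError("agent did not finish within max_steps")


def scripted_llm(transcript: list[str]) -> dict:
    """Hypothetical stand-in for a chat model: look up a price, then answer."""
    if len(transcript) == 1:
        return {"tool": "price", "input": "NVDA"}
    return {"answer": f"done after seeing: {transcript[-1]}"}


tools = {"price": lambda ticker: f"{ticker}=100"}

if __name__ == "__main__":
    print(run_agent(scripted_llm, tools, "price check"))
    # done after seeing: price(NVDA) -> NVDA=100
```

The `max_steps` cap is the simplest guard against a model that never stops calling tools.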

## What's different here?

We're not building yet another agent framework. You're free to use LangChain, custom code, or whatever setup you like.

What we're giving you is the missing piece: a way to run these workflows **reliably, observably, and at scale, with zero rewrites.**

With Flyte, you get:

- Prompt + tool traceability and full state retention
- Built-in retries, caching, and failure recovery
- A native way to plug in your agents; no magic syntax required

## How it works: step-by-step walkthrough

This simulation is powered by a Flyte task that orchestrates multiple intelligent agents working together to analyze a company's stock and make informed trading decisions.

![Trading agents schema](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/trading-agents/schema.png)
_Trading agents schema_

### Entry point

Everything begins with a top-level Flyte task called `main`, which serves as the entry point to the workflow.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "akshare==1.16.98",
#     "backtrader==1.9.78.123",
#     "boto3==1.39.9",
#     "chainlit==2.5.5",
#     "eodhd==1.0.32",
#     "feedparser==6.0.11",
#     "finnhub-python==2.4.23",
#     "langchain-experimental==0.3.4",
#     "langchain-openai==0.3.23",
#     "pandas==2.3.0",
#     "parsel==1.10.0",
#     "praw==7.8.1",
#     "pytz==2025.2",
#     "questionary==2.1.0",
#     "redis==6.2.0",
#     "requests==2.32.4",
#     "stockstats==0.6.5",
#     "tqdm==4.67.1",
#     "tushare==1.4.21",
#     "typing-extensions==4.14.0",
#     "yfinance==0.2.63",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from copy import deepcopy

import agents
import agents.analysts
from agents.managers import create_research_manager, create_risk_manager
from agents.researchers import create_bear_researcher, create_bull_researcher
from agents.risk_debators import (
    create_neutral_debator,
    create_risky_debator,
    create_safe_debator,
)
from agents.trader import create_trader
from agents.utils.utils import AgentState
from flyte_env import DEEP_THINKING_LLM, QUICK_THINKING_LLM, env, flyte
from langchain_openai import ChatOpenAI
from reflection import (
    reflect_bear_researcher,
    reflect_bull_researcher,
    reflect_research_manager,
    reflect_risk_manager,
    reflect_trader,
)

@env.task
async def process_signal(full_signal: str, QUICK_THINKING_LLM: str) -> str:
    """Process a full trading signal to extract the core decision."""

    messages = [
        {
            "role": "system",
            "content": """You are an efficient assistant designed to analyze paragraphs or
financial reports provided by a group of analysts.
Your task is to extract the investment decision: SELL, BUY, or HOLD.
Provide only the extracted decision (SELL, BUY, or HOLD) as your output,
without adding any additional text or information.""",
        },
        {"role": "human", "content": full_signal},
    ]

    return ChatOpenAI(model=QUICK_THINKING_LLM).invoke(messages).content

async def run_analyst(analyst_name, state, online_tools):
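    # Look up the analyst factory by name, e.g. "market" -> create_market_analyst.
    # The caller passes in a deep copy of the state, so each analyst runs in isolation.
    run_fn = getattr(agents.analysts, f"create_{analyst_name}_analyst")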

    # Run the analyst's chain
    result_state = await run_fn(QUICK_THINKING_LLM, state, online_tools)

    # Determine the report key
    report_key = (
        "sentiment_report"
        if analyst_name == "social_media"
        else f"{analyst_name}_report"
    )
    report_value = getattr(result_state, report_key)

    return result_state.messages[1:], report_key, report_value

@env.task
async def main(
    selected_analysts: list[str] = [
        "market",
        "fundamentals",
        "news",
        "social_media",
    ],
    max_debate_rounds: int = 1,
    max_risk_discuss_rounds: int = 1,
    online_tools: bool = True,
    company_name: str = "NVDA",
    trade_date: str = "2024-05-12",
) -> tuple[str, AgentState]:
    if not selected_analysts:
        raise ValueError(
            "No analysts selected. Please select at least one analyst from market, fundamentals, news, or social_media."
        )

    state = AgentState(
        messages=[{"role": "human", "content": company_name}],
        company_of_interest=company_name,
        trade_date=str(trade_date),
    )

    # Run all analysts concurrently
    results = await asyncio.gather(
        *[
            run_analyst(analyst, deepcopy(state), online_tools)
            for analyst in selected_analysts
        ]
    )

    # Flatten and append all resulting messages into the shared state
    for messages, report_attr, report in results:
        state.messages.extend(messages)
        setattr(state, report_attr, report)

    # Bull/Bear debate loop
    state = await create_bull_researcher(QUICK_THINKING_LLM, state)  # Start with bull
    while state.investment_debate_state.count < 2 * max_debate_rounds:
        current = state.investment_debate_state.current_response
        if current.startswith("Bull"):
            state = await create_bear_researcher(QUICK_THINKING_LLM, state)
        else:
            state = await create_bull_researcher(QUICK_THINKING_LLM, state)

    state = await create_research_manager(DEEP_THINKING_LLM, state)
    state = await create_trader(QUICK_THINKING_LLM, state)

    # Risk debate loop
    state = await create_risky_debator(QUICK_THINKING_LLM, state)  # Start with risky
    while state.risk_debate_state.count < 3 * max_risk_discuss_rounds:
        speaker = state.risk_debate_state.latest_speaker
        if speaker == "Risky":
            state = await create_safe_debator(QUICK_THINKING_LLM, state)
        elif speaker == "Safe":
            state = await create_neutral_debator(QUICK_THINKING_LLM, state)
        else:
            state = await create_risky_debator(QUICK_THINKING_LLM, state)

    state = await create_risk_manager(DEEP_THINKING_LLM, state)
    decision = await process_signal(state.final_trade_decision, QUICK_THINKING_LLM)

    return decision, state

@env.task
async def reflect_and_store(state: AgentState, returns: str) -> str:
    await asyncio.gather(
        reflect_bear_researcher(state, returns),
        reflect_bull_researcher(state, returns),
        reflect_trader(state, returns),
        reflect_risk_manager(state, returns),
        reflect_research_manager(state, returns),
    )

    return "Reflection completed."

# Run the reflection task after the main function
@env.task(cache="disable")
async def reflect_on_decisions(
    returns: str,
    selected_analysts: list[str] = [
        "market",
        "fundamentals",
        "news",
        "social_media",
    ],
    max_debate_rounds: int = 1,
    max_risk_discuss_rounds: int = 1,
    online_tools: bool = True,
    company_name: str = "NVDA",
    trade_date: str = "2024-05-12",
) -> str:
    _, state = await main(
        selected_analysts,
        max_debate_rounds,
        max_risk_discuss_rounds,
        online_tools,
        company_name,
        trade_date,
    )

    return await reflect_and_store(state, returns)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

    # run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
    # print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py*

This task accepts several inputs:

- the list of analysts to run,
- the number of debate and risk discussion rounds,
- a flag to enable online tools,
- the company you're evaluating,
- and the target trading date.

The most interesting parameter here is the list of analysts. It determines which analyst agents are invoked and shapes the overall structure of the simulation. Based on this input, the task runs the selected analysts concurrently with `asyncio.gather`, giving each one its own deep copy of the shared state and merging their messages and reports back afterward.
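The fan-out/merge pattern in `main` can be isolated into a runnable sketch. `State` and `fake_analyst` are hypothetical stand-ins for `AgentState` and the real analyst chains; what matters is the shape: a deep copy per analyst, `asyncio.gather`, then a sequential merge.

```python
import asyncio
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class State:
    # Hypothetical, trimmed-down stand-in for AgentState
    messages: list[str] = field(default_factory=list)
    reports: dict[str, str] = field(default_factory=dict)


async def fake_analyst(name: str, state: State) -> tuple[list[str], str, str]:
    """Hypothetical analyst: mutates its private copy, returns new messages and a report."""
    state.messages.append(f"{name}: analysis")
    # Drop the initial company message, mirroring result_state.messages[1:]
    return state.messages[1:], f"{name}_report", f"{name} looks fine"


async def fan_out(selected: list[str], state: State) -> State:
    # Each analyst gets its own deep copy, so concurrent runs cannot clobber shared state
    results = await asyncio.gather(*(fake_analyst(n, deepcopy(state)) for n in selected))
    # Merge in selection order (gather preserves input order)
    for messages, key, report in results:
        state.messages.extend(messages)
        state.reports[key] = report
    return state
```

Merging only after `gather` returns keeps the shared state free of interleaved writes, at the cost of holding every analyst's output until all of them finish.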

The `main` task is written as a regular asynchronous Python function wrapped with Flyte's task decorator. No domain-specific language or orchestration glue is needed: it's idiomatic Python, optionally using async for better performance. The task environment is configured once and shared across all tasks for consistency.

```
import flyte

QUICK_THINKING_LLM = "gpt-4o-mini"
DEEP_THINKING_LLM = "o4-mini"

env = flyte.TaskEnvironment(
    name="trading-agents",
    secrets=[
        flyte.Secret(key="finnhub_api_key", as_env_var="FINNHUB_API_KEY"),
        flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY"),
    ],
    image=flyte.Image.from_uv_script("main.py", name="trading-agents", pre=True),
    resources=flyte.Resources(cpu="1"),
    cache="auto",
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/flyte_env.py*

### Analyst agents

Each analyst agent comes equipped with a set of tools and a carefully designed prompt tailored to its specific domain. These tools are modular Flyte tasks, such as downloading financial reports or computing technical indicators, and they benefit from Flyte's built-in caching to avoid redundant computation.

```
from datetime import datetime

import pandas as pd
import tools.interface as interface
import yfinance as yf
from flyte_env import env

from flyte.io import File

@env.task
async def get_reddit_news(
    curr_date: str,  # Date you want to get news for in yyyy-mm-dd format
) -> str:
    """
    Retrieve global news from Reddit within a specified time frame.
    Args:
        curr_date (str): Date you want to get news for in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing the latest global news
        from Reddit in the specified time frame.
    """

    global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)

    return global_news_result

@env.task
async def get_finnhub_news(
    ticker: str,  # Ticker of a company, e.g. AAPL, TSM
    start_date: str,  # Start date in yyyy-mm-dd format
    end_date: str,  # End date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest news about a given stock from Finnhub within a date range
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        start_date (str): Start date in yyyy-mm-dd format
        end_date (str): End date in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing news about the company
        within the date range from start_date to end_date
    """

    end_date_str = end_date

    end_date = datetime.strptime(end_date, "%Y-%m-%d")
    start_date = datetime.strptime(start_date, "%Y-%m-%d")
    look_back_days = (end_date - start_date).days

    finnhub_news_result = interface.get_finnhub_news(
        ticker, end_date_str, look_back_days
    )

    return finnhub_news_result

@env.task
async def get_reddit_stock_info(
    ticker: str,  # Ticker of a company. e.g. AAPL, TSM
    curr_date: str,  # Current date you want to get news for
) -> str:
    """
    Retrieve the latest news about a given stock from Reddit, given the current date.
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        curr_date (str): current date in yyyy-mm-dd format to get news for
    Returns:
        str: A formatted dataframe containing the latest news about the company on the given date
    """

    stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)

    return stock_news_results

@env.task
async def get_YFin_data(
    symbol: str,  # ticker symbol of the company
    start_date: str,  # Start date in yyyy-mm-dd format
    end_date: str,  # End date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        start_date (str): Start date in yyyy-mm-dd format
        end_date (str): End date in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing the stock price data
        for the specified ticker symbol in the specified date range.
    """

    result_data = interface.get_YFin_data(symbol, start_date, end_date)

    return result_data

@env.task
async def get_YFin_data_online(
    symbol: str,  # ticker symbol of the company
    start_date: str,  # Start date in yyyy-mm-dd format
    end_date: str,  # End date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        start_date (str): Start date in yyyy-mm-dd format
        end_date (str): End date in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing the stock price data
        for the specified ticker symbol in the specified date range.
    """

    result_data = interface.get_YFin_data_online(symbol, start_date, end_date)

    return result_data

@env.task
async def cache_market_data(symbol: str, start_date: str, end_date: str) -> File:
    data_file = f"{symbol}-YFin-data-{start_date}-{end_date}.csv"

    data = yf.download(
        symbol,
        start=start_date,
        end=end_date,
        multi_level_index=False,
        progress=False,
        auto_adjust=True,
    )
    data = data.reset_index()
    data.to_csv(data_file, index=False)

    return await File.from_local(data_file)

@env.task
async def get_stockstats_indicators_report(
    symbol: str,  # ticker symbol of the company
    indicator: str,  # technical indicator to get the analysis and report of
    curr_date: str,  # The current trading date you are trading on, YYYY-mm-dd
    look_back_days: int = 30,  # how many days to look back
) -> str:
    """
    Retrieve stock stats indicators for a given ticker symbol and indicator.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        indicator (str): Technical indicator to get the analysis and report of
        curr_date (str): The current trading date you are trading on, YYYY-mm-dd
        look_back_days (int): How many days to look back, default is 30
    Returns:
        str: A formatted dataframe containing the stock stats indicators
        for the specified ticker symbol and indicator.
    """

    today_date = pd.Timestamp.today()

    end_date = today_date
    start_date = today_date - pd.DateOffset(years=15)
    start_date = start_date.strftime("%Y-%m-%d")
    end_date = end_date.strftime("%Y-%m-%d")

    data_file = await cache_market_data(symbol, start_date, end_date)
    local_data_file = await data_file.download()

    result_stockstats = interface.get_stock_stats_indicators_window(
        symbol, indicator, curr_date, look_back_days, False, local_data_file
    )

    return result_stockstats

@env.task
async def get_stockstats_indicators_report_online(
    symbol: str,  # ticker symbol of the company
    indicator: str,  # technical indicator to get the analysis and report of
    curr_date: str,  # The current trading date you are trading on, YYYY-mm-dd
    look_back_days: int = 30,  # how many days to look back
) -> str:
    """
    Retrieve stock stats indicators for a given ticker symbol and indicator.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        indicator (str): Technical indicator to get the analysis and report of
        curr_date (str): The current trading date you are trading on, YYYY-mm-dd
        look_back_days (int): How many days to look back, default is 30
    Returns:
        str: A formatted dataframe containing the stock stats indicators
        for the specified ticker symbol and indicator.
    """

    today_date = pd.Timestamp.today()

    end_date = today_date
    start_date = today_date - pd.DateOffset(years=15)
    start_date = start_date.strftime("%Y-%m-%d")
    end_date = end_date.strftime("%Y-%m-%d")

    data_file = await cache_market_data(symbol, start_date, end_date)
    local_data_file = await data_file.download()

    result_stockstats = interface.get_stock_stats_indicators_window(
        symbol, indicator, curr_date, look_back_days, True, local_data_file
    )

    return result_stockstats

@env.task
async def get_finnhub_company_insider_sentiment(
    ticker: str,  # ticker symbol for the company
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve insider sentiment information about a company (retrieved
    from public SEC information) for the past 30 days
    Args:
        ticker (str): ticker symbol of the company
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the sentiment in the past 30 days starting at curr_date
    """

    data_sentiment = interface.get_finnhub_company_insider_sentiment(
        ticker, curr_date, 30
    )

    return data_sentiment

@env.task
async def get_finnhub_company_insider_transactions(
    ticker: str,  # ticker symbol
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve insider transaction information about a company
    (retrieved from public SEC information) for the past 30 days
    Args:
        ticker (str): ticker symbol of the company
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's insider transactions/trading information in the past 30 days
    """

    data_trans = interface.get_finnhub_company_insider_transactions(
        ticker, curr_date, 30
    )

    return data_trans

@env.task
async def get_simfin_balance_sheet(
    ticker: str,  # ticker symbol
    freq: str,  # reporting frequency of the company's financial history: annual/quarterly
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve the most recent balance sheet of a company
    Args:
        ticker (str): ticker symbol of the company
        freq (str): reporting frequency of the company's financial history: annual / quarterly
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's most recent balance sheet
    """

    data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date)

    return data_balance_sheet

@env.task
async def get_simfin_cashflow(
    ticker: str,  # ticker symbol
    freq: str,  # reporting frequency of the company's financial history: annual/quarterly
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve the most recent cash flow statement of a company
    Args:
        ticker (str): ticker symbol of the company
        freq (str): reporting frequency of the company's financial history: annual / quarterly
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's most recent cash flow statement
    """

    data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date)

    return data_cashflow

@env.task
async def get_simfin_income_stmt(
    ticker: str,  # ticker symbol
    freq: str,  # reporting frequency of the company's financial history: annual/quarterly
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve the most recent income statement of a company
    Args:
        ticker (str): ticker symbol of the company
        freq (str): reporting frequency of the company's financial history: annual / quarterly
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's most recent income statement
    """

    data_income_stmt = interface.get_simfin_income_statements(ticker, freq, curr_date)

    return data_income_stmt

@env.task
async def get_google_news(
    query: str,  # Query to search with
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest news from Google News for a query, looking back
    a fixed 7 days from curr_date.
    Args:
        query (str): Query to search with
        curr_date (str): Current date in yyyy-mm-dd format
    Returns:
        str: A formatted string containing the latest news from Google News
        based on the query and date range.
    """

    google_news_results = interface.get_google_news(query, curr_date, 7)

    return google_news_results

@env.task
async def get_stock_news_openai(
    ticker: str,  # the company's ticker
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest news about a given stock by using OpenAI's news API.
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        curr_date (str): Current date in yyyy-mm-dd format
    Returns:
        str: A formatted string containing the latest news about the company on the given date.
    """

    openai_news_results = interface.get_stock_news_openai(ticker, curr_date)

    return openai_news_results

@env.task
async def get_global_news_openai(
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest macroeconomics news on a given date using OpenAI's macroeconomics news API.
    Args:
        curr_date (str): Current date in yyyy-mm-dd format
    Returns:
        str: A formatted string containing the latest macroeconomic news on the given date.
    """

    openai_news_results = interface.get_global_news_openai(curr_date)

    return openai_news_results

@env.task
async def get_fundamentals_openai(
    ticker: str,  # the company's ticker
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest fundamental information about a given stock
    on a given date by using OpenAI's news API.
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        curr_date (str): Current date in yyyy-mm-dd format

    Returns:
        str: A formatted string containing the latest fundamental information
        about the company on the given date.
    """

    openai_fundamentals_results = interface.get_fundamentals_openai(ticker, curr_date)

    return openai_fundamentals_results
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/tools/toolkit.py*
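
Because the analysts hand these functions to the model through LangChain's `bind_tools`, each tool's name, type hints, and docstring are broadly what the model sees when deciding what to call. The sketch below approximates that conversion with the standard library's `inspect` module. `function_to_schema` is a hypothetical helper, not LangChain's actual implementation, and the stand-in `get_google_news` only mirrors the toolkit task's signature:

```python
import inspect
import json

# JSON Schema type names for the annotations the toolkit uses (illustrative subset).
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def function_to_schema(fn) -> dict:
    """Hypothetical helper: build a minimal OpenAI-style tool schema from a function."""
    doc = inspect.getdoc(fn) or ""
    # Everything before "Args:" becomes the description, with whitespace collapsed.
    description = " ".join(doc.split("Args:")[0].split())
    properties = {}
    required = []
    for name, param in inspect.signature(fn).parameters.items():
        properties[name] = {"type": _JSON_TYPES.get(param.annotation, "string")}
        if param.default is inspect.Parameter.empty:
            required.append(name)  # no default value, so the model must supply it
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


# Stand-in with the same signature as the toolkit's get_google_news task.
async def get_google_news(query: str, curr_date: str) -> str:
    """
    Retrieve the latest news from Google News for a query over the past 7 days.
    Args:
        query (str): Query to search with
        curr_date (str): Current date in yyyy-mm-dd format
    """
    return ""


print(json.dumps(function_to_schema(get_google_news), indent=2))
```

Keeping signatures and docstrings in sync matters for the same reason: the model can only call a tool correctly if its description matches the parameters the function actually accepts.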

When initialized, an analyst enters a structured reasoning loop (via LangChain), where it can call tools, observe outputs, and refine its internal state before generating a final report. These reports are later consumed by downstream agents.

The analysts are defined below. The news analyst, for example, interprets global events and macroeconomic signals. Each analyst specifies the tools it can access, and the LLM selects which ones to use based on context.

```
import asyncio

from agents.utils.utils import AgentState
from flyte_env import env
from langchain_core.messages import ToolMessage, convert_to_openai_messages
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from tools import toolkit

import flyte

MAX_ITERATIONS = 5

# {{docs-fragment agent_helper}}
async def run_chain_with_tools(
    type: str, state: AgentState, llm: str, system_message: str, tool_names: list[str]
) -> AgentState:
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK; another assistant with different tools"
                " will help where you left off. Execute what you can to make progress."
                " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
                " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
                " You have access to the following tools: {tool_names}.\n{system_message}"
                " For your reference, the current date is {current_date}. The company we want to look at is {ticker}.",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join(tool_names))
    prompt = prompt.partial(current_date=state.trade_date)
    prompt = prompt.partial(ticker=state.company_of_interest)

    chain = prompt | ChatOpenAI(model=llm).bind_tools(
        [getattr(toolkit, tool_name).func for tool_name in tool_names]
    )

    iteration = 0
    while iteration < MAX_ITERATIONS:
        result = await chain.ainvoke(state.messages)
        state.messages.append(convert_to_openai_messages(result))

        if not result.tool_calls:
            # Final response — no tools required
            setattr(state, f"{type}_report", result.content or "")
            break

        # Run all tool calls in parallel
        async def run_single_tool(tool_call):
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool = getattr(toolkit, tool_name, None)

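            if not tool:
                # Reply to every tool call: the chat API expects a response for each tool_call id
                return ToolMessage(
                    tool_call_id=tool_call["id"],
                    name=tool_name,
                    content=f"Error: unknown tool '{tool_name}'",
                )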

            content = await tool(**tool_args)
            return ToolMessage(
                tool_call_id=tool_call["id"], name=tool_name, content=content
            )

        with flyte.group(f"tool_calls_iteration_{iteration}"):
            tool_messages = await asyncio.gather(
                *[run_single_tool(tc) for tc in result.tool_calls]
            )

        # Add valid tool results to state
        tool_messages = [msg for msg in tool_messages if msg]
        state.messages.extend(convert_to_openai_messages(tool_messages))

        iteration += 1
    else:
        # Reached iteration cap — optionally raise or log
        print(f"Max iterations ({MAX_ITERATIONS}) reached for {type}")

    return state

# {{/docs-fragment agent_helper}}

@env.task
async def create_fundamentals_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [toolkit.get_fundamentals_openai]
    else:
        tools = [
            toolkit.get_finnhub_company_insider_sentiment,
            toolkit.get_finnhub_company_insider_transactions,
            toolkit.get_simfin_balance_sheet,
            toolkit.get_simfin_cashflow,
            toolkit.get_simfin_income_stmt,
        ]

    system_message = (
        "You are a researcher tasked with analyzing fundamental information over the past week about a company. "
        "Please write a comprehensive report of the company's fundamental information such as financial documents, "
        "company profile, basic company financials, company financial history, insider sentiment, and insider "
        "transactions to gain a full view of the company's "
        "fundamental information to inform traders. Make sure to include as much detail as possible. "
        "Do not simply state the trends are mixed, "
        "provide detailed and fine-grained analysis and insights that may help traders make decisions. "
        "Make sure to append a Markdown table at the end of the report to organize key points in the report, "
        "organized and easy to read."
    )

    tool_names = [tool.func.__name__ for tool in tools]

    return await run_chain_with_tools(
        "fundamentals", state, llm, system_message, tool_names
    )

@env.task
async def create_market_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [
            toolkit.get_YFin_data_online,
            toolkit.get_stockstats_indicators_report_online,
        ]
    else:
        tools = [
            toolkit.get_YFin_data,
            toolkit.get_stockstats_indicators_report,
        ]

    system_message = (
        """You are a trading assistant tasked with analyzing financial markets.
Your role is to select the **most relevant indicators** for a given market condition
or trading strategy from the following list.
The goal is to choose up to **8 indicators** that provide complementary insights without redundancy.
Categories and each category's indicators are:

Moving Averages:
- close_50_sma: 50 SMA: A medium-term trend indicator.
Usage: Identify trend direction and serve as dynamic support/resistance.
Tips: It lags price; combine with faster indicators for timely signals.
- close_200_sma: 200 SMA: A long-term trend benchmark.
Usage: Confirm overall market trend and identify golden/death cross setups.
Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.
- close_10_ema: 10 EMA: A responsive short-term average.
Usage: Capture quick shifts in momentum and potential entry points.
Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.

MACD Related:
- macd: MACD: Computes momentum via differences of EMAs.
Usage: Look for crossovers and divergence as signals of trend changes.
Tips: Confirm with other indicators in low-volatility or sideways markets.
- macds: MACD Signal: An EMA smoothing of the MACD line.
Usage: Use crossovers with the MACD line to trigger trades.
Tips: Should be part of a broader strategy to avoid false positives.
- macdh: MACD Histogram: Shows the gap between the MACD line and its signal.
Usage: Visualize momentum strength and spot divergence early.
Tips: Can be volatile; complement with additional filters in fast-moving markets.

Momentum Indicators:
- rsi: RSI: Measures momentum to flag overbought/oversold conditions.
Usage: Apply 70/30 thresholds and watch for divergence to signal reversals.
Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.

Volatility Indicators:
- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands.
Usage: Acts as a dynamic benchmark for price movement.
Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.
- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line.
Usage: Signals potential overbought conditions and breakout zones.
Tips: Confirm signals with other tools; prices may ride the band in strong trends.
- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line.
Usage: Indicates potential oversold conditions.
Tips: Use additional analysis to avoid false reversal signals.
- atr: ATR: Averages true range to measure volatility.
Usage: Set stop-loss levels and adjust position sizes based on current market volatility.
Tips: It's a reactive measure, so use it as part of a broader risk management strategy.

Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume.
Usage: Confirm trends by integrating price action with volume data.
Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.

- Select indicators that provide diverse and complementary information.
Avoid redundancy (e.g., do not select both rsi and stochrsi).
Also briefly explain why they are suitable for the given market context.
When you tool call, please use the exact name of the indicators provided above as they are defined parameters,
otherwise your call will fail.
Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators.
Write a very detailed and nuanced report of the trends you observe.
Do not simply state the trends are mixed, provide detailed and fine-grained analysis
and insights that may help traders make decisions."""
        """ Make sure to append a Markdown table at the end of the report to
        organize key points in the report, organized and easy to read."""
    )

    tool_names = [tool.func.__name__ for tool in tools]
    return await run_chain_with_tools("market", state, llm, system_message, tool_names)

# {{docs-fragment news_analyst}}
@env.task
async def create_news_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [
            toolkit.get_global_news_openai,
            toolkit.get_google_news,
        ]
    else:
        tools = [
            toolkit.get_finnhub_news,
            toolkit.get_reddit_news,
            toolkit.get_google_news,
        ]

    system_message = (
        "You are a news researcher tasked with analyzing recent news and trends over the past week. "
        "Please write a comprehensive report of the current state of the world that is relevant for "
        "trading and macroeconomics. "
        "Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, "
        "provide detailed and fine-grained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
        organized and easy to read."""
    )

    tool_names = [tool.func.__name__ for tool in tools]

    return await run_chain_with_tools("news", state, llm, system_message, tool_names)

# {{/docs-fragment news_analyst}}

@env.task
async def create_social_media_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [toolkit.get_stock_news_openai]
    else:
        tools = [toolkit.get_reddit_stock_info]

    system_message = (
        "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, "
        "recent company news, and public sentiment for a specific company over the past week. "
        "You will be given a company's name; your objective is to write a comprehensive long report "
        "detailing your analysis, insights, and implications for traders and investors on this company's current state "
        "after looking at social media and what people are saying about that company, "
        "analyzing sentiment data of what people feel each day about the company, and looking at recent company news. "
        "Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends "
        "are mixed, provide detailed and fine-grained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
        organized and easy to read."""
    )

    tool_names = [tool.func.__name__ for tool in tools]

    return await run_chain_with_tools(
        "sentiment", state, llm, system_message, tool_names
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/analysts.py*
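
The tool-dispatch step inside `run_chain_with_tools` can be isolated into a framework-free sketch: look each requested tool up by name, run all calls from one model turn concurrently with `asyncio.gather`, and return one result per call. The `TOOLS` registry, the dict-shaped tool calls and messages, and the error text for unknown tools are assumptions for illustration; in the tutorial the lookup is `getattr(toolkit, tool_name)` and results are wrapped in LangChain `ToolMessage` objects.

```python
import asyncio
from typing import Awaitable, Callable


# Hypothetical tools standing in for the toolkit's async tasks.
async def get_google_news(query: str, curr_date: str) -> str:
    await asyncio.sleep(0.01)  # simulate I/O latency
    return f"news for {query!r} up to {curr_date}"


async def get_global_news_openai(curr_date: str) -> str:
    await asyncio.sleep(0.01)
    return f"macro news up to {curr_date}"


TOOLS: dict[str, Callable[..., Awaitable[str]]] = {
    "get_google_news": get_google_news,
    "get_global_news_openai": get_global_news_openai,
}


async def run_single_tool(tool_call: dict) -> dict:
    """Execute one tool call and return a tool message keyed by the call id."""
    tool = TOOLS.get(tool_call["name"])
    if tool is None:
        # Answer unknown tools too, so every call id still gets a response.
        content = f"Error: unknown tool '{tool_call['name']}'"
    else:
        content = await tool(**tool_call["args"])
    return {"role": "tool", "tool_call_id": tool_call["id"], "content": content}


async def dispatch(tool_calls: list[dict]) -> list[dict]:
    # gather preserves input order, so results line up with the requested calls.
    return await asyncio.gather(*(run_single_tool(tc) for tc in tool_calls))


calls = [
    {"id": "call_1", "name": "get_google_news", "args": {"query": "AAPL", "curr_date": "2024-05-10"}},
    {"id": "call_2", "name": "get_global_news_openai", "args": {"curr_date": "2024-05-10"}},
    {"id": "call_3", "name": "no_such_tool", "args": {}},
]
for message in asyncio.run(dispatch(calls)):
    print(message)
```

Running every call from a single model turn concurrently keeps that turn's latency close to the slowest tool rather than the sum of all of them.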

Each analyst delegates to the `run_chain_with_tools` helper, which binds the analyst's tools to the model, loops for at most `MAX_ITERATIONS` (5) model calls, runs each round of tool calls in parallel inside a `flyte.group`, and stores the final answer on the state as `<type>_report`. Capping the iteration count prevents runaway loops when the model keeps requesting tools. As agents reason, their message history is preserved in their state and passed along to the next agent in the chain.
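
The iteration cap relies on Python's `while ... else` construct: the `else` branch runs only when the loop exits without a `break`, which is how the helper detects that the model never produced a final answer. Below is a minimal, LLM-free sketch of that control flow; `Turn`, `scripted_model`, and `run_bounded_loop` are hypothetical names, and the canned turns stand in for real model output.

```python
from dataclasses import dataclass, field
from typing import Optional

MAX_ITERATIONS = 5


@dataclass
class Turn:
    content: str
    tool_calls: list[str] = field(default_factory=list)


def scripted_model(turns: list[Turn]):
    """Hypothetical model stub: replays canned turns, then keeps requesting a tool."""
    replay = iter(turns)

    def next_turn(messages: list[str]) -> Turn:
        return next(replay, Turn(content="", tool_calls=["get_google_news"]))

    return next_turn


def run_bounded_loop(model, max_iterations: int = MAX_ITERATIONS) -> tuple[Optional[str], int]:
    """Return (final report or None, number of model calls made)."""
    messages: list[str] = []
    report = None
    calls = 0
    iteration = 0
    while iteration < max_iterations:
        result = model(messages)
        calls += 1
        messages.append(result.content)
        if not result.tool_calls:
            report = result.content  # final answer: stop early
            break
        # A real agent would execute result.tool_calls here and append their outputs.
        messages.extend(f"<output of {name}>" for name in result.tool_calls)
        iteration += 1
    else:
        # Only reached when the loop ran out of iterations without a break.
        print(f"Max iterations ({max_iterations}) reached")
    return report, calls


# Two tool rounds, then a final answer: finishes after three model calls.
print(run_bounded_loop(scripted_model([
    Turn("", ["get_google_news"]),
    Turn("", ["get_global_news_openai"]),
    Turn("Final news report"),
])))

# A model that never stops calling tools is cut off after MAX_ITERATIONS calls.
print(run_bounded_loop(scripted_model([])))
```

As in the tutorial's helper, hitting the cap leaves the report unset rather than raising; raising instead would make a stuck analyst fail loudly.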


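Once all analyst reports are complete, they are available on the shared `AgentState` as `market_report`, `sentiment_report`, `news_report`, and `fundamentals_report`, and the state is passed to the next stage of the workflow.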
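
The hand-off works because `run_chain_with_tools` writes each analyst's final answer to a field named after its type (`setattr(state, f"{type}_report", ...)`), and the researchers later read those fields back. Here is a stdlib sketch of that pattern, assuming a pared-down, hypothetical `MiniAgentState` dataclass in place of the tutorial's `AgentState` and canned text in place of LLM output; the final join mirrors how the bear researcher builds `curr_situation`.

```python
from dataclasses import dataclass

ANALYST_TYPES = ("market", "sentiment", "news", "fundamentals")


@dataclass
class MiniAgentState:
    """Hypothetical, pared-down stand-in for the tutorial's AgentState."""
    company_of_interest: str
    trade_date: str
    market_report: str = ""
    sentiment_report: str = ""
    news_report: str = ""
    fundamentals_report: str = ""


def record_report(state: MiniAgentState, analyst_type: str, text: str) -> MiniAgentState:
    # Same naming convention as the helper: "<type>_report".
    setattr(state, f"{analyst_type}_report", text)
    return state


def current_situation(state: MiniAgentState) -> str:
    # Join the four reports the way the researchers do before querying memory.
    return "\n\n".join(getattr(state, f"{t}_report") for t in ANALYST_TYPES)


state = MiniAgentState(company_of_interest="AAPL", trade_date="2024-05-10")
for analyst_type in ANALYST_TYPES:
    state = record_report(state, analyst_type, f"{analyst_type} findings")
print(current_situation(state))
```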

### Research agents

The research phase consists of two agents: a bullish researcher and a bearish one. They evaluate the company from opposing viewpoints, drawing on the analysts' reports. Unlike analysts, they don't use tools. Their role is to interpret, critique, and develop positions based on the evidence.
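
Each researcher call appends one argument to the shared debate history and to its own side's history, then increments a turn counter, as `create_bear_researcher` does with `InvestmentDebateState`. The stdlib sketch below isolates that turn-taking. `MiniDebateState`, the canned `argue` function, letting the bull speak first, and the `max_rounds` stopping rule are all assumptions for illustration; the tutorial's actual stopping condition is not shown in this section.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class MiniDebateState:
    """Hypothetical stand-in for InvestmentDebateState."""
    history: str = ""
    bull_history: str = ""
    bear_history: str = ""
    current_response: str = ""
    count: int = 0


def argue(side: str, state: MiniDebateState) -> MiniDebateState:
    # Canned argument in place of an LLM call; it records which turn produced it.
    argument = f"{side} Analyst: argument #{state.count + 1}"
    updated = replace(
        state,
        history=state.history + "\n" + argument,
        current_response=argument,
        count=state.count + 1,
    )
    # Append to this side's own history ("bull_history" or "bear_history").
    field_name = f"{side.lower()}_history"
    return replace(updated, **{field_name: getattr(state, field_name) + "\n" + argument})


def run_debate(max_rounds: int) -> MiniDebateState:
    state = MiniDebateState()
    # One round = one bull turn followed by one bear turn.
    while state.count < 2 * max_rounds:
        side = "Bull" if state.count % 2 == 0 else "Bear"
        state = argue(side, state)
    return state


final = run_debate(max_rounds=2)
print(final.count)
print(final.history.strip().splitlines())
```

Keeping a per-side history alongside the shared transcript lets each researcher review its own earlier points while still answering the opponent's latest argument through `current_response`.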

```
from agents.utils.utils import AgentState, InvestmentDebateState, memory_init
from flyte_env import env
from langchain_openai import ChatOpenAI

# {{docs-fragment bear_researcher}}
@env.task
async def create_bear_researcher(llm: str, state: AgentState) -> AgentState:
    investment_debate_state = state.investment_debate_state
    history = investment_debate_state.history
    bear_history = investment_debate_state.bear_history

    current_response = investment_debate_state.current_response
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="bear-researcher")

    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)

    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""You are a Bear Analyst making the case against investing in the stock.
Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators.
Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.

Key points to focus on:

- Risks and Challenges: Highlight factors like market saturation, financial instability,
or macroeconomic threats that could hinder the stock's performance.
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation,
or threats from competitors.
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning,
exposing weaknesses or over-optimistic assumptions.
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points
and debating effectively rather than simply listing facts.

Resources available:

Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bull argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate
that demonstrates the risks and weaknesses of investing in the stock.
You must also address reflections and learn from lessons and mistakes you made in the past.
"""

    response = ChatOpenAI(model=llm).invoke(prompt)

    argument = f"Bear Analyst: {response.content}"

    new_investment_debate_state = InvestmentDebateState(
        history=history + "\n" + argument,
        bear_history=bear_history + "\n" + argument,
        bull_history=investment_debate_state.bull_history,
        current_response=argument,
        count=investment_debate_state.count + 1,
    )

    state.investment_debate_state = new_investment_debate_state
    return state


@env.task
async def create_bull_researcher(llm: str, state: AgentState) -> AgentState:
    investment_debate_state = state.investment_debate_state
    history = investment_debate_state.history
    bull_history = investment_debate_state.bull_history

    current_response = investment_debate_state.current_response
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="bull-researcher")

    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)

    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""You are a Bull Analyst advocating for investing in the stock.
Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages,
and positive market indicators.
Leverage the provided research and data to address concerns and counter bearish arguments effectively.

Key points to focus on:
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing
concerns thoroughly and showing why the bull perspective holds stronger merit.
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points
and debating effectively rather than just listing data.

Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bear argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate
that demonstrates the strengths of the bull position.
You must also address reflections and learn from lessons and mistakes you made in the past.
"""

    response = ChatOpenAI(model=llm).invoke(prompt)

    argument = f"Bull Analyst: {response.content}"

    new_investment_debate_state = InvestmentDebateState(
        history=history + "\n" + argument,
        bull_history=bull_history + "\n" + argument,
        bear_history=investment_debate_state.bear_history,
        current_response=argument,
        count=investment_debate_state.count + 1,
    )

    state.investment_debate_state = new_investment_debate_state
    return state
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/researchers.py*
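Both researchers update the debate state in the same way. The sketch below isolates that pattern. It assumes a hypothetical frozen `DebateState` dataclass that mirrors the `InvestmentDebateState` fields used above, and the real class may differ. The speaker's argument is appended to the shared `history` and to that side's own history, `current_response` holds the latest argument, and `count` advances by one per turn.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class DebateState:
    # Hypothetical mirror of the InvestmentDebateState fields the researchers update.
    history: str = ""
    bull_history: str = ""
    bear_history: str = ""
    current_response: str = ""
    count: int = 0


def record_turn(state: DebateState, side: str, text: str) -> DebateState:
    """Return a new state with one argument appended, leaving the input untouched."""
    argument = f"{side} Analyst: {text}"
    own = "bull_history" if side == "Bull" else "bear_history"
    return replace(
        state,
        history=state.history + "\n" + argument,
        current_response=argument,  # main() checks this prefix to pick the next speaker
        count=state.count + 1,
        **{own: getattr(state, own) + "\n" + argument},
    )


s0 = DebateState()
s1 = record_turn(s0, "Bull", "Data-center demand keeps growing.")
s2 = record_turn(s1, "Bear", "The valuation already prices that in.")
print(s2.count, s2.current_response)
```

The `"Bull Analyst:"` and `"Bear Analyst:"` prefixes do real work here. `main` checks `current_response.startswith("Bull")` to decide who speaks next.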

To aid reasoning, each researcher also retrieves relevant "memories" (reflections on similar past situations) from a vector store, which gives it richer historical context. The number of debate rounds is configurable through `max_debate_rounds`. Once the bull and bear have completed those rounds, a research manager agent reviews their arguments and makes a final investment decision.

```
from agents.utils.utils import (
    AgentState,
    InvestmentDebateState,
    RiskDebateState,
    memory_init,
)
from flyte_env import env
from langchain_openai import ChatOpenAI

@env.task
async def create_research_manager(llm: str, state: AgentState) -> AgentState:
    history = state.investment_debate_state.history
    investment_debate_state = state.investment_debate_state
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="research-manager")

    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)

    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate
this round of debate and make a definitive decision:
align with the bear analyst, the bull analyst,
or choose Hold only if it is strongly justified based on the arguments presented.

Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning.
Your recommendation—Buy, Sell, or Hold—must be clear and actionable.
Avoid defaulting to Hold simply because both sides have valid points;
commit to a stance grounded in the debate's strongest arguments.

Additionally, develop a detailed investment plan for the trader. This should include:

Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations.
Use these insights to refine your decision-making and ensure you are learning and improving.
Present your analysis conversationally, as if speaking naturally, without special formatting.

Here are your past reflections on mistakes:
\"{past_memory_str}\"

Here is the debate:
Debate History:
{history}"""
    response = ChatOpenAI(model=llm).invoke(prompt)

    new_investment_debate_state = InvestmentDebateState(
        judge_decision=response.content,
        history=investment_debate_state.history,
        bear_history=investment_debate_state.bear_history,
        bull_history=investment_debate_state.bull_history,
        current_response=response.content,
        count=investment_debate_state.count,
    )

    state.investment_debate_state = new_investment_debate_state
    state.investment_plan = response.content

    return state


@env.task
async def create_risk_manager(llm: str, state: AgentState) -> AgentState:
    history = state.risk_debate_state.history
    risk_debate_state = state.risk_debate_state
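    # Use the trader's proposal (set by create_trader), the same plan the risk debators argued over
    trader_plan = state.trader_investment_plan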
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="risk-manager")

    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)

    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""As the Risk Management Judge and Debate Facilitator,
your goal is to evaluate the debate between three risk analysts—Risky,
Neutral, and Safe/Conservative—and determine the best course of action for the trader.
Your decision must result in a clear recommendation: Buy, Sell, or Hold.
Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid.
Strive for clarity and decisiveness.

Guidelines for Decision-Making:
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**,
and adjust it based on the analysts' insights.
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments
and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.

Deliverables:
- A clear and actionable recommendation: Buy, Sell, or Hold.
- Detailed reasoning anchored in the debate and past reflections.

---

**Analysts Debate History:**
{history}

---

Focus on actionable insights and continuous improvement.
Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""

    response = ChatOpenAI(model=llm).invoke(prompt)

    new_risk_debate_state = RiskDebateState(
        judge_decision=response.content,
        history=risk_debate_state.history,
        risky_history=risk_debate_state.risky_history,
        safe_history=risk_debate_state.safe_history,
        neutral_history=risk_debate_state.neutral_history,
        latest_speaker="Judge",
        current_risky_response=risk_debate_state.current_risky_response,
        current_safe_response=risk_debate_state.current_safe_response,
        current_neutral_response=risk_debate_state.current_neutral_response,
        count=risk_debate_state.count,
    )

    state.risk_debate_state = new_risk_debate_state
    state.final_trade_decision = response.content

    return state
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/managers.py*
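The turn-taking between the two researchers is driven by a loop in `main`. The version below uses hypothetical stub agents so you can see how `max_debate_rounds` maps to turns. `speak` and `investment_debate` are illustrative names, and no LLM is called.

```python
import asyncio


async def speak(side: str, state: dict) -> dict:
    # Hypothetical stub researcher: the real tasks call an LLM; this only records who spoke.
    await asyncio.sleep(0)
    return {
        "current_response": f"{side} Analyst: <argument>",
        "count": state["count"] + 1,
        "speakers": state["speakers"] + [side],
    }


async def investment_debate(max_debate_rounds: int) -> list:
    state = {"current_response": "", "count": 0, "speakers": []}
    state = await speak("Bull", state)  # the bull always opens
    # One round is one bull turn plus one bear turn, hence the factor of 2.
    while state["count"] < 2 * max_debate_rounds:
        nxt = "Bear" if state["current_response"].startswith("Bull") else "Bull"
        state = await speak(nxt, state)
    return state["speakers"]


print(asyncio.run(investment_debate(2)))
```

The bull opens unconditionally, so `max_debate_rounds=0` still yields one bull turn before the research manager runs.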

### Trading agent

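The trader agent takes the research manager's investment plan, which already consolidates the analysts' and researchers' insights, and turns it into a concrete transaction proposal ending in BUY, HOLD, or SELL. It synthesizes competing signals and produces a conclusion such as _Buy for long-term growth despite short-term volatility_. The risk agents then stress-test this proposal before the final decision is made.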

```
from agents.utils.utils import AgentState, memory_init
from flyte_env import env
from langchain_core.messages import convert_to_openai_messages
from langchain_openai import ChatOpenAI

@env.task
async def create_trader(llm: str, state: AgentState) -> AgentState:
    company_name = state.company_of_interest
    investment_plan = state.investment_plan
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="trader")

    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)

    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    context = {
        "role": "user",
        "content": f"Based on a comprehensive analysis by a team of analysts, "
        f"here is an investment plan tailored for {company_name}. "
        "This plan incorporates insights from current technical market trends, "
        "macroeconomic indicators, and social media sentiment. "
        "Use this plan as a foundation for evaluating your next trading decision.\n\n"
        f"Proposed Investment Plan: {investment_plan}\n\n"
        "Leverage these insights to make an informed and strategic decision.",
    }

    messages = [
        {
            "role": "system",
            "content": f"""You are a trading agent analyzing market data to make investment decisions.
Based on your analysis, provide a specific recommendation to buy, sell, or hold.
End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**'
to confirm your recommendation.
Do not forget to utilize lessons from past decisions to learn from your mistakes.
Here are some reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
        },
        context,
    ]

    result = ChatOpenAI(model=llm).invoke(messages)

    state.messages.append(convert_to_openai_messages(result))
    state.trader_investment_plan = result.content
    state.sender = "Trader"

    return state

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/trader.py*
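The trader's prompt asks it to end with `FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**`. The tutorial itself extracts decisions with an LLM call (`process_signal` in `main.py`). If you also want a cheap, deterministic check, one option is a regex parser like the hypothetical `extract_proposal` below. This is an alternative sketch, not part of the tutorial, and it assumes the model follows the requested format.

```python
import re
from typing import Optional

# Matches the closing line the trader prompt requires, e.g.
# "FINAL TRANSACTION PROPOSAL: **BUY**". Bold markers are optional, and the
# (?!/) lookahead rejects an echoed template such as "**BUY/HOLD/SELL**".
_PROPOSAL = re.compile(
    r"FINAL TRANSACTION PROPOSAL:\s*\*{0,2}\s*(BUY|HOLD|SELL)\b(?!/)",
    re.IGNORECASE,
)


def extract_proposal(text: str) -> Optional[str]:
    """Return BUY, HOLD, or SELL from the last proposal line, or None if there is none."""
    matches = _PROPOSAL.findall(text)
    return matches[-1].upper() if matches else None


print(extract_proposal("Reasoning goes here.\nFINAL TRANSACTION PROPOSAL: **HOLD**"))
```

A `None` result is a useful signal to fall back to an LLM-based extractor such as `process_signal`.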

### Risk agents

The risk team consists of three debaters with different risk tolerances: risky, neutral, and safe (conservative). They assess the trader's proposal through lenses like market volatility, liquidity, and systemic risk. As in the bull-bear debate, they argue in turns, after which a risk manager makes the final call.

```
from agents.utils.utils import AgentState, RiskDebateState
from flyte_env import env
from langchain_openai import ChatOpenAI

@env.task
async def create_risky_debator(llm: str, state: AgentState) -> AgentState:
    risk_debate_state = state.risk_debate_state
    history = risk_debate_state.history
    risky_history = risk_debate_state.risky_history

    current_safe_response = risk_debate_state.current_safe_response
    current_neutral_response = risk_debate_state.current_neutral_response

    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    trader_decision = state.trader_investment_plan

    prompt = f"""As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities,
emphasizing bold strategies and competitive advantages.
When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential,
and innovative benefits—even when these come with elevated risk.
Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views.
Specifically, respond directly to each point made by the conservative and neutral analysts,
countering with data-driven rebuttals and persuasive reasoning.
Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative.
Here is the trader's decision:

{trader_decision}

Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative
and neutral stances to demonstrate why your high-reward perspective offers the best path forward.
Incorporate insights from the following sources into your arguments:

Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here are the last arguments from the conservative analyst: {current_safe_response}
Here are the last arguments from the neutral analyst: {current_neutral_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.

Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic,
and asserting the benefits of risk-taking to outpace market norms.
Maintain a focus on debating and persuading, not just presenting data.
Challenge each counterpoint to underscore why a high-risk approach is optimal.
Output conversationally as if you are speaking without any special formatting."""

    response = ChatOpenAI(model=llm).invoke(prompt)

    argument = f"Risky Analyst: {response.content}"

    new_risk_debate_state = RiskDebateState(
        history=history + "\n" + argument,
        risky_history=risky_history + "\n" + argument,
        safe_history=risk_debate_state.safe_history,
        neutral_history=risk_debate_state.neutral_history,
        latest_speaker="Risky",
        current_risky_response=argument,
        current_safe_response=current_safe_response,
        current_neutral_response=current_neutral_response,
        count=risk_debate_state.count + 1,
    )

    state.risk_debate_state = new_risk_debate_state
    return state


@env.task
async def create_safe_debator(llm: str, state: AgentState) -> AgentState:
    risk_debate_state = state.risk_debate_state
    history = risk_debate_state.history
    safe_history = risk_debate_state.safe_history

    current_risky_response = risk_debate_state.current_risky_response
    current_neutral_response = risk_debate_state.current_neutral_response

    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    trader_decision = state.trader_investment_plan

    prompt = f"""As the Safe/Conservative Risk Analyst, your primary objective is to protect assets,
minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation,
carefully assessing potential losses, economic downturns, and market volatility.
When evaluating the trader's decision or plan, critically examine high-risk elements,
pointing out where the decision may expose the firm to undue risk and where more cautious
alternatives could secure long-term gains.
Here is the trader's decision:

{trader_decision}

Your task is to actively counter the arguments of the Risky and Neutral Analysts,
highlighting where their views may overlook potential threats or fail to prioritize sustainability.
Respond directly to their points, drawing from the following data sources
to build a convincing case for a low-risk approach adjustment to the trader's decision:

Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here is the last response from the risky analyst: {current_risky_response}
Here is the last response from the neutral analyst: {current_neutral_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.

Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked.
Address each of their counterpoints to showcase why a conservative stance is ultimately the
safest path for the firm's assets.
Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy
over their approaches.
Output conversationally as if you are speaking without any special formatting."""

    response = ChatOpenAI(model=llm).invoke(prompt)

    argument = f"Safe Analyst: {response.content}"

    new_risk_debate_state = RiskDebateState(
        history=history + "\n" + argument,
        risky_history=risk_debate_state.risky_history,
        safe_history=safe_history + "\n" + argument,
        neutral_history=risk_debate_state.neutral_history,
        latest_speaker="Safe",
        current_risky_response=current_risky_response,
        current_safe_response=argument,
        current_neutral_response=current_neutral_response,
        count=risk_debate_state.count + 1,
    )

    state.risk_debate_state = new_risk_debate_state
    return state

@env.task
async def create_neutral_debator(llm: str, state: AgentState) -> AgentState:
    risk_debate_state = state.risk_debate_state
    history = risk_debate_state.history
    neutral_history = risk_debate_state.neutral_history

    current_risky_response = risk_debate_state.current_risky_response
    current_safe_response = risk_debate_state.current_safe_response

    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    trader_decision = state.trader_investment_plan

    prompt = f"""As the Neutral Risk Analyst, your role is to provide a balanced perspective,
weighing both the potential benefits and risks of the trader's decision or plan.
You prioritize a well-rounded approach, evaluating the upsides
and downsides while factoring in broader market trends,
potential economic shifts, and diversification strategies. Here is the trader's decision:

{trader_decision}

Your task is to challenge both the Risky and Safe Analysts,
pointing out where each perspective may be overly optimistic or overly cautious.
Use insights from the following data sources to support a moderate, sustainable strategy
to adjust the trader's decision:

Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here is the last response from the risky analyst: {current_risky_response}
Here is the last response from the safe analyst: {current_safe_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.

Engage actively by analyzing both sides critically, addressing weaknesses in the risky
and conservative arguments to advocate for a more balanced approach.
Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds,
providing growth potential while safeguarding against extreme volatility.
Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to
the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""

    response = ChatOpenAI(model=llm).invoke(prompt)

    argument = f"Neutral Analyst: {response.content}"

    new_risk_debate_state = RiskDebateState(
        history=history + "\n" + argument,
        risky_history=risk_debate_state.risky_history,
        safe_history=risk_debate_state.safe_history,
        neutral_history=neutral_history + "\n" + argument,
        latest_speaker="Neutral",
        current_risky_response=current_risky_response,
        current_safe_response=current_safe_response,
        current_neutral_response=argument,
        count=risk_debate_state.count + 1,
    )

    state.risk_debate_state = new_risk_debate_state
    return state
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/risk_debators.py*
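The risk debate uses `latest_speaker` to rotate through the three debaters. This hypothetical sketch reproduces only the turn order of the loop in `main`, with no LLM calls. `NEXT_SPEAKER` and `risk_debate_order` are illustrative names.

```python
# Hypothetical sketch of the turn order in main()'s risk debate loop.
NEXT_SPEAKER = {"Risky": "Safe", "Safe": "Neutral", "Neutral": "Risky"}


def risk_debate_order(max_risk_discuss_rounds: int) -> list:
    speakers = ["Risky"]  # the risky debater always opens
    # len(speakers) plays the role of count: one increment per turn, three turns per round.
    while len(speakers) < 3 * max_risk_discuss_rounds:
        speakers.append(NEXT_SPEAKER[speakers[-1]])
    return speakers


print(risk_debate_order(2))
```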

The risk manager's verdict (Buy, Sell, or Hold) is the final decision of the trading simulation. `main` then passes the manager's full response to `process_signal`, which extracts only the one-word decision.

You can visualize this full pipeline in the Flyte/Union UI, where every step is logged.
You’ll see input/output metadata for each tool and agent task.
Thanks to Flyte's caching, repeated steps are skipped unless inputs change, saving time and compute resources.
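Flyte handles caching for you, and a task can opt out explicitly, as `reflect_on_decisions` does with `cache="disable"`. The toy decorator below only illustrates the idea of input-keyed caching. `cached_task`, `CACHE`, and `CALLS` are hypothetical, and this is not how Flyte implements its cache.

```python
import functools
import hashlib
import json

CACHE = {}  # cache key -> stored result
CALLS = {"count": 0}  # number of real (non-cached) executions


def cached_task(fn):
    """Toy input-keyed cache: reuse a stored result when the task and inputs match."""

    @functools.wraps(fn)
    def wrapper(**inputs):  # keyword arguments only, to keep key construction simple
        payload = json.dumps({"task": fn.__qualname__, "inputs": inputs}, sort_keys=True)
        key = hashlib.sha256(payload.encode()).hexdigest()
        if key not in CACHE:
            CACHE[key] = fn(**inputs)
        return CACHE[key]

    return wrapper


@cached_task
def summarize(company: str, trade_date: str) -> str:
    CALLS["count"] += 1
    return f"summary of {company} on {trade_date}"


first = summarize(company="NVDA", trade_date="2024-05-12")
second = summarize(company="NVDA", trade_date="2024-05-12")  # same inputs: cache hit
third = summarize(company="NVDA", trade_date="2024-05-13")  # new input: runs again
print(CALLS["count"])
```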

### Retaining agent memory with S3 vectors

To help agents learn from past decisions, we persist their memory in a vector store. In this example, we use an [S3 vector](https://aws.amazon.com/s3/features/vectors/) bucket for its simplicity and tight integration with Flyte and Union, but any vector database can be used.
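To make the retrieval step concrete, here is a toy, in-memory stand-in for the memory object the agents use. It is a sketch under assumptions. `ToyMemory`, `embed`, and `add_situations` are hypothetical, a bag-of-words hashing embedding replaces a real embedding model, and nothing here talks to S3 Vectors. It mirrors only what the agent code relies on: `get_memories(situation, n_matches=...)` returns dicts with a `"recommendation"` key. The `"situation"` and `"score"` keys are illustrative extras.

```python
import hashlib
import math


def embed(text: str, dim: int = 256) -> list:
    # Toy embedding (hypothetical): hash each word into one of `dim` buckets, then L2-normalize.
    vec = [0.0] * dim
    for word in text.lower().split():
        bucket = int(hashlib.sha256(word.encode()).hexdigest(), 16) % dim
        vec[bucket] += 1.0
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


class ToyMemory:
    """In-memory stand-in exposing the get_memories() call the agents make."""

    def __init__(self) -> None:
        self._items = []  # (vector, situation, recommendation) triples

    def add_situations(self, pairs) -> None:
        # Each pair is (situation text, lesson/recommendation text).
        for situation, recommendation in pairs:
            self._items.append((embed(situation), situation, recommendation))

    def get_memories(self, current_situation: str, n_matches: int = 1) -> list:
        query = embed(current_situation)
        # Vectors are unit length, so the dot product equals cosine similarity.
        scored = sorted(
            ((sum(q * v for q, v in zip(query, vec)), sit, rec) for vec, sit, rec in self._items),
            key=lambda item: item[0],
            reverse=True,
        )
        return [
            {"situation": sit, "recommendation": rec, "score": score}
            for score, sit, rec in scored[:n_matches]
        ]


memory = ToyMemory()
memory.add_situations([
    ("chip demand surging in data center buildout", "Lesson about a semiconductor rally"),
    ("retail sales slump and weak consumer spending", "Lesson about a retail downturn"),
])
past = memory.get_memories("data center chip demand", n_matches=1)
# Same joining pattern the agents use to build past_memory_str.
past_memory_str = "".join(rec["recommendation"] + "\n\n" for rec in past)
```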

Note: To use the S3 vector store, make sure your IAM role has the following permissions configured:

```
s3vectors:CreateVectorBucket
s3vectors:CreateIndex
s3vectors:PutVectors
s3vectors:GetIndex
s3vectors:GetVectors
s3vectors:QueryVectors
s3vectors:GetVectorBucket
```

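After a trade decision has played out, you can run the `reflect_on_decisions` task with the observed returns (for example, `"+3.2% gain over 5 days"`). It re-runs `main` with the same inputs to rebuild the agent state. Then `reflect_and_store` concurrently generates reflections for the bull and bear researchers, the trader, the research manager, and the risk manager. Each reflection evaluates whether that agent's recommendation aligned with the actual outcome and is stored in the vector store. These stored insights can later be retrieved to provide historical context and improve future decision-making.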

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "akshare==1.16.98",
#     "backtrader==1.9.78.123",
#     "boto3==1.39.9",
#     "chainlit==2.5.5",
#     "eodhd==1.0.32",
#     "feedparser==6.0.11",
#     "finnhub-python==2.4.23",
#     "langchain-experimental==0.3.4",
#     "langchain-openai==0.3.23",
#     "pandas==2.3.0",
#     "parsel==1.10.0",
#     "praw==7.8.1",
#     "pytz==2025.2",
#     "questionary==2.1.0",
#     "redis==6.2.0",
#     "requests==2.32.4",
#     "stockstats==0.6.5",
#     "tqdm==4.67.1",
#     "tushare==1.4.21",
#     "typing-extensions==4.14.0",
#     "yfinance==0.2.63",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from copy import deepcopy

import agents
import agents.analysts
from agents.managers import create_research_manager, create_risk_manager
from agents.researchers import create_bear_researcher, create_bull_researcher
from agents.risk_debators import (
    create_neutral_debator,
    create_risky_debator,
    create_safe_debator,
)
from agents.trader import create_trader
from agents.utils.utils import AgentState
from flyte_env import DEEP_THINKING_LLM, QUICK_THINKING_LLM, env, flyte
from langchain_openai import ChatOpenAI
from reflection import (
    reflect_bear_researcher,
    reflect_bull_researcher,
    reflect_research_manager,
    reflect_risk_manager,
    reflect_trader,
)

@env.task
async def process_signal(full_signal: str, QUICK_THINKING_LLM: str) -> str:
    """Process a full trading signal to extract the core decision."""

    messages = [
        {
            "role": "system",
            "content": """You are an efficient assistant designed to analyze paragraphs or
financial reports provided by a group of analysts.
Your task is to extract the investment decision: SELL, BUY, or HOLD.
Provide only the extracted decision (SELL, BUY, or HOLD) as your output,
without adding any additional text or information.""",
        },
        {"role": "human", "content": full_signal},
    ]

    return ChatOpenAI(model=QUICK_THINKING_LLM).invoke(messages).content

async def run_analyst(analyst_name, state, online_tools):
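    # Look up the analyst task by name, e.g. create_market_analyst.
    # The caller passes in a deep copy of the state, so analysts run in isolation.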
    run_fn = getattr(agents.analysts, f"create_{analyst_name}_analyst")

    # Run the analyst's chain
    result_state = await run_fn(QUICK_THINKING_LLM, state, online_tools)

    # Determine the report key
    report_key = (
        "sentiment_report"
        if analyst_name == "social_media"
        else f"{analyst_name}_report"
    )
    report_value = getattr(result_state, report_key)

    return result_state.messages[1:], report_key, report_value

@env.task
async def main(
    selected_analysts: list[str] = [
        "market",
        "fundamentals",
        "news",
        "social_media",
    ],
    max_debate_rounds: int = 1,
    max_risk_discuss_rounds: int = 1,
    online_tools: bool = True,
    company_name: str = "NVDA",
    trade_date: str = "2024-05-12",
) -> tuple[str, AgentState]:
    if not selected_analysts:
        raise ValueError(
            "No analysts selected. Please select at least one analyst from market, fundamentals, news, or social_media."
        )

    state = AgentState(
        messages=[{"role": "human", "content": company_name}],
        company_of_interest=company_name,
        trade_date=str(trade_date),
    )

    # Run all analysts concurrently
    results = await asyncio.gather(
        *[
            run_analyst(analyst, deepcopy(state), online_tools)
            for analyst in selected_analysts
        ]
    )

    # Flatten and append all resulting messages into the shared state
    for messages, report_attr, report in results:
        state.messages.extend(messages)
        setattr(state, report_attr, report)

    # Bull/Bear debate loop
    state = await create_bull_researcher(QUICK_THINKING_LLM, state)  # Start with bull
    while state.investment_debate_state.count < 2 * max_debate_rounds:
        current = state.investment_debate_state.current_response
        if current.startswith("Bull"):
            state = await create_bear_researcher(QUICK_THINKING_LLM, state)
        else:
            state = await create_bull_researcher(QUICK_THINKING_LLM, state)

    state = await create_research_manager(DEEP_THINKING_LLM, state)
    state = await create_trader(QUICK_THINKING_LLM, state)

    # Risk debate loop
    state = await create_risky_debator(QUICK_THINKING_LLM, state)  # Start with risky
    while state.risk_debate_state.count < 3 * max_risk_discuss_rounds:
        speaker = state.risk_debate_state.latest_speaker
        if speaker == "Risky":
            state = await create_safe_debator(QUICK_THINKING_LLM, state)
        elif speaker == "Safe":
            state = await create_neutral_debator(QUICK_THINKING_LLM, state)
        else:
            state = await create_risky_debator(QUICK_THINKING_LLM, state)

    state = await create_risk_manager(DEEP_THINKING_LLM, state)
    decision = await process_signal(state.final_trade_decision, QUICK_THINKING_LLM)

    return decision, state

@env.task
async def reflect_and_store(state: AgentState, returns: str) -> str:
    await asyncio.gather(
        reflect_bear_researcher(state, returns),
        reflect_bull_researcher(state, returns),
        reflect_trader(state, returns),
        reflect_risk_manager(state, returns),
        reflect_research_manager(state, returns),
    )

    return "Reflection completed."

# Re-run main to rebuild the agent state, then reflect on the observed returns
@env.task(cache="disable")
async def reflect_on_decisions(
    returns: str,
    selected_analysts: list[str] = [
        "market",
        "fundamentals",
        "news",
        "social_media",
    ],
    max_debate_rounds: int = 1,
    max_risk_discuss_rounds: int = 1,
    online_tools: bool = True,
    company_name: str = "NVDA",
    trade_date: str = "2024-05-12",
) -> str:
    _, state = await main(
        selected_analysts,
        max_debate_rounds,
        max_risk_discuss_rounds,
        online_tools,
        company_name,
        trade_date,
    )

    return await reflect_and_store(state, returns)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

    # run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
    # print(run.url)

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py*
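The reflection step fans out the same way the analysts do: `reflect_and_store` uses `asyncio.gather` to await five independent reflections. The sketch below shows only that concurrency shape. `reflect`, `reflect_all`, `ROLES`, and the dict used as a store are hypothetical stand-ins for the tutorial's `reflect_*` helpers and vector store, and no LLM is called.

```python
import asyncio

# Hypothetical role names matching the five reflections in reflect_and_store.
ROLES = ["bull_researcher", "bear_researcher", "trader", "research_manager", "risk_manager"]


async def reflect(role: str, situation: str, returns: str, store: dict) -> None:
    # Stub: per the description above, the real helpers judge whether the role's
    # recommendation matched the outcome and write that reflection to the vector store.
    await asyncio.sleep(0)
    store.setdefault(role, []).append((situation, f"{role} reflection for {returns}"))


async def reflect_all(situation: str, returns: str, store: dict) -> str:
    # The reflections are independent of each other, so they can run concurrently.
    await asyncio.gather(*[reflect(role, situation, returns, store) for role in ROLES])
    return "Reflection completed."


store: dict = {}
print(asyncio.run(reflect_all("NVDA reports", "+3.2% gain over 5 days", store)))
```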

### Running the simulation

First, create secrets for your OpenAI API key (from [openai.com](https://platform.openai.com/api-keys)) and your Finnhub API key (from [finnhub.io](https://finnhub.io/)):

```
flyte create secret openai_api_key <YOUR_OPENAI_API_KEY>
flyte create secret finnhub_api_key <YOUR_FINNHUB_API_KEY>
```

Then [clone the repo](https://github.com/unionai/unionai-examples), navigate to the `v2/tutorials/trading_agents` directory, and run the following commands:

```
flyte create config --endpoint <FLYTE_OR_UNION_ENDPOINT> --project <PROJECT_NAME> --domain <DOMAIN_NAME> --builder remote
uv run main.py
```

If you'd like to run the `reflect_on_decisions` task instead, comment out the `main` function call and uncomment the `reflect_on_decisions` call in the `__main__` block:

```
if __name__ == "__main__":
    flyte.init_from_config()
    # run = flyte.run(main)
    # print(run.url)
    # run.wait()

    run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py*

Then run:

```
uv run main.py
```

## Why Flyte? _(A quick note before you go)_

You might now be wondering: can't I just build all this with Python and LangChain?
Absolutely. But as your project grows, you'll likely run into these challenges:

1.  **Observability**: Agent workflows can feel opaque. You send a prompt, get a response, but what happened in between?

    - Were the right tools used?
    - Were correct arguments passed?
    - How did the LLM reason through intermediate steps?
    - Why did it fail?

    Flyte gives you a window into each of these stages.

2.  **Multi-agent coordination**: Real-world applications often require multiple agents with distinct roles and responsibilities. In such cases, you'll need:

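    - Isolated state per agent,
    - Shared context where needed,
    - And coordination, whether sequential or parallel.

    Managing this manually gets fragile, fast. With Flyte, each agent step can run as its own task, and you express sequential or parallel coordination with ordinary `await` and `asyncio.gather` calls, as the tutorial's `main` task does.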

3.  **Scalability**: Agents and tools might need to run in isolated or containerized environments. Whether you're scaling out to more agents or more powerful hardware, Flyte lets you scale without taxing your local machine or racking up unnecessary cloud bills.

4.  **Durability & recovery**: LLM-based workflows are often long-running and expensive. If something fails halfway:

    - Do you lose all progress?
    - Replay everything from scratch?

    With Flyte, you get built-in caching, checkpointing, and recovery, so you can resume where you left off.
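The isolated-state and coordination points above match what the tutorial's `main` task does. Each analyst receives a `deepcopy` of the state, the analysts run concurrently through `asyncio.gather`, and their reports are merged back into the shared state. The sketch below reproduces that pattern in plain `asyncio` without Flyte or LLM calls. `State`, `fake_analyst`, and `run_analysts` are hypothetical stand-ins, not part of the tutorial code.

```python
import asyncio
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class State:
    # Hypothetical, trimmed-down stand-in for the tutorial's AgentState.
    company: str
    messages: list[str] = field(default_factory=list)
    reports: dict[str, str] = field(default_factory=dict)


async def fake_analyst(name: str, state: State) -> tuple[str, list[str], str]:
    """Hypothetical stand-in for an LLM-backed analyst; it only touches its own copy."""
    await asyncio.sleep(0.01)  # simulate I/O-bound work such as an API call
    state.messages.append(f"{name} looked at {state.company}")
    return name, state.messages[-1:], f"{name} report for {state.company}"


async def run_analysts(names: list[str], state: State) -> State:
    # Fan out: every analyst gets an isolated deep copy, and they run concurrently.
    results = await asyncio.gather(*(fake_analyst(n, deepcopy(state)) for n in names))
    # Fan in: merge each analyst's new messages and report into the shared state.
    for name, new_messages, report in results:
        state.messages.extend(new_messages)
        state.reports[name] = report
    return state


final = asyncio.run(run_analysts(["market", "news"], State(company="NVDA")))
print(final.messages)  # ['market looked at NVDA', 'news looked at NVDA']
```

`asyncio.gather` returns results in the order the coroutines were passed, so the merged history stays deterministic even though the analysts run concurrently.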

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/code-agent ===

# Run LLM-generated code

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/code_runner).

This example demonstrates how to run code generated by a large language model (LLM) using a `ContainerTask`.
The agent takes a user’s question, generates Flyte 2 code using the Flyte 2 documentation as context, and runs it in an isolated container.
If the execution fails, the agent reflects on the error and retries until the code runs successfully or a configurable iteration limit (`max_iterations`) is reached.

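Using `ContainerTask` runs each generated script in its own container. That container is separate from the agent's environment, which is the only environment that mounts the OpenAI API key secret.
This lets you execute arbitrary generated logic without running it inside the agent's own process.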
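In the listing below, the command in `code_runner_task` sends the script's combined stdout and stderr (`2>&1`) to `/var/outputs/result`. It also writes the exit status (`echo $?`) to `/var/outputs/exit_code`. These file names match the task's declared `outputs`. The sketch below reproduces that file-based contract locally with `subprocess`. To stay dependency-free, it assumes the current Python interpreter in place of `uv run --script`. `run_script_like_container` is a hypothetical helper, not a Flyte API. The sketch also shows why `code_check` calls `.strip()` before comparing the exit code with `"0"`: `echo` appends a newline.

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def run_script_like_container(script: Path, outputs_dir: Path) -> dict[str, str]:
    """Mimic the command: combined output -> `result`, exit status -> `exit_code`."""
    proc = subprocess.run(
        [sys.executable, str(script)],  # assumption: plain Python instead of `uv run --script`
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # like `2>&1`
        text=True,
    )
    (outputs_dir / "result").write_text(proc.stdout)
    # `echo $?` writes the status followed by a newline.
    (outputs_dir / "exit_code").write_text(f"{proc.returncode}\n")
    # Read each output back as a string from the file of the same name.
    return {name: (outputs_dir / name).read_text() for name in ("result", "exit_code")}


with tempfile.TemporaryDirectory() as tmp:
    tmp_path = Path(tmp)
    script = tmp_path / "script.py"
    script.write_text("import sys\nprint('hi')\nsys.exit(3)\n")
    outputs = run_script_like_container(script, tmp_path)

print(repr(outputs["exit_code"]))  # '3\n'
print(outputs["exit_code"].strip() != "0")  # True, so this run counts as a failure
```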

## What this example demonstrates

- How to combine LLM generation with programmatic execution.
- How to run untrusted or dynamically generated code securely.
- How to iteratively improve code using agent-like behavior.
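The generate, check, and retry behavior follows a bounded loop. The tutorial's `main` task stops when `error == "no" or iterations >= max_iterations`. The sketch below isolates that control flow in plain Python. `solve_with_retries`, `syntax_check`, and `fake_generate` are hypothetical. `syntax_check` only compiles the candidate as a safe stand-in for the container run, and `fake_generate` stands in for the LLM.

```python
from typing import Callable


def solve_with_retries(
    generate: Callable[[list[str]], str],
    check: Callable[[str], tuple[bool, str]],
    max_iterations: int = 3,
) -> tuple[str, int, bool]:
    """Generate, check, and feed errors back until success or max_iterations."""
    feedback: list[str] = []
    iterations = 0
    while True:
        candidate = generate(feedback)
        iterations += 1
        ok, output = check(candidate)
        # Same stopping rule as the tutorial's main loop:
        # finish on success or once the iteration budget is spent.
        if ok or iterations >= max_iterations:
            return candidate, iterations, ok
        feedback.append(f"Your solution failed the code execution test: {output}")


def syntax_check(candidate: str) -> tuple[bool, str]:
    """Stand-in checker: compiles the code but never executes it."""
    try:
        compile(candidate, "<generated>", "exec")
        return True, "ok"
    except SyntaxError as exc:
        return False, f"SyntaxError: {exc.msg}"


def fake_generate(feedback: list[str]) -> str:
    """Hypothetical LLM stand-in: broken at first, fixed once it sees feedback."""
    return "print('hello'" if not feedback else "print('hello')"


code, iterations, ok = solve_with_retries(fake_generate, syntax_check)
print(code, iterations, ok)  # print('hello') 2 True
```

Returning the last candidate even on failure mirrors the tutorial. When the budget runs out, the tutorial still reports the final attempt and its execution output rather than raising.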

## Setting up the agent environment

Let's start by importing the necessary libraries and setting up two environments: one for the container task and another for the agent task.
This example follows the `uv` script format to declare dependencies.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "langchain-core==0.3.66",
#    "langchain-openai==0.3.24",
#    "langchain-community==0.3.26",
#    "beautifulsoup4==4.13.4",
#    "docker==7.1.0",
# ]
# ///
```

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "langchain-core==0.3.66",
#    "langchain-openai==0.3.24",
#    "langchain-community==0.3.26",
#    "beautifulsoup4==4.13.4",
#    "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)

# {{/docs-fragment code_runner_task}}

# {{docs-fragment env}}
import tempfile
from typing import Optional

from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field

container_env = flyte.TaskEnvironment.from_task(
    "code-runner-container", code_runner_task
)

env = flyte.TaskEnvironment(
    name="code_runner",
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
    resources=flyte.Resources(cpu=1),
    depends_on=[container_env],
)

# {{/docs-fragment env}}

# {{docs-fragment code_base_model}}
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )

# {{/docs-fragment code_base_model}}

# {{docs-fragment agent_state}}
class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None

# {{/docs-fragment agent_state}}

# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Grader prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.

Use the following pattern to execute the code:

<code>
if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))
</code>

Your response will be shown to the user.
Here is a full set of documentation:

-------
{context}
-------

Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)

    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain

# {{/docs-fragment generate_code_gen_chain}}

# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort the list based on the URLs and get the text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))

    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content

# {{/docs-fragment docs_retriever}}

# {{docs-fragment generate}}
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution

    Args:
        question (str): The user question
        state (dict): The current graph state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, generation
    """

    print("---GENERATING CODE SOLUTION---")

    # Seed the history with the user question on the first pass so that it
    # stays in the prompt on every retry, not just the first generation.
    messages = state.messages or [{"role": "user", "content": question}]
    iterations = state.iterations
    error = state.error

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution
    code_solution = code_gen_chain.invoke(
        {
            "context": concatenated_content,
            "messages": messages,
        }
    )

    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]

    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )

# {{/docs-fragment generate}}

# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """

    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for code body
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )

    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )

    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )

# {{/docs-fragment code_check}}

# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors

    Args:
        state (dict): The current graph state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, reflection
    """

    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )

    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )

# {{/docs-fragment reflect}}

# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)

    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation

                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code

                code_output = state.output

                return f"""{prefix}

{imports}
{code}

Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*

> [!NOTE]
> You can set up access to the OpenAI API using a Flyte secret.
>
> ```
> flyte create secret openai_api_key <YOUR_OPENAI_API_KEY>
> ```

We store the LLM-generated code in a structured `Code` model with separate `prefix` (a description of the problem and approach), `imports`, and `code` fields. This allows us to:

- Enforce consistent formatting
- Make debugging easier
- Log and analyze generations systematically

Capturing the description alongside the imports and code body keeps each generation transparent and makes it easier to iterate or trace issues over time.

```
from pydantic import BaseModel, Field


class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*

We then define a state model that persists the agent's history across iterations: the message history, the most recent code generation,
the iteration count, an error flag, and the output of the last execution.

Maintaining this history lets the agent reflect on past attempts, helps it avoid repeating mistakes,
and allows it to iteratively improve the generated code.

```
from typing import Optional

from pydantic import BaseModel, Field

# {{docs-fragment code_base_model}}
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )

# {{/docs-fragment code_base_model}}

# {{docs-fragment agent_state}}
class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None

# {{/docs-fragment agent_state}}

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
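
To see how `AgentState` carries history through the generate → check → retry loop in `main`, here is a dependency-free sketch. It makes several assumptions. Standard-library dataclasses stand in for pydantic. `fake_generate` is a hypothetical stand-in for the LLM chain. `exec` replaces the container run, and `exec` is not a sandbox. The `reflect` step is omitted.

```python
import contextlib
import io
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Code:
    # Same prefix / imports / code split as the tutorial's Code model.
    prefix: str = ""
    imports: str = ""
    code: str = ""


@dataclass
class AgentState:
    # Same fields as the tutorial's AgentState (dataclass instead of pydantic).
    messages: list = field(default_factory=list)
    generation: Code = field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None


def fake_generate(state: AgentState) -> AgentState:
    # Hypothetical stand-in for the LLM: it only "fixes" the bug once an
    # error report from the checker is present in the message history.
    saw_error = any("failed" in m["content"] for m in state.messages if m["role"] == "user")
    body = "print(math.sqrt(16))" if saw_error else "print(sqrt(16))"
    solution = Code(prefix="Square root demo", imports="import math", code=body)
    state.messages.append({"role": "assistant", "content": solution.code})
    return AgentState(
        messages=state.messages,
        generation=solution,
        iterations=state.iterations + 1,
        error=state.error,
        output=state.output,
    )


def check(state: AgentState) -> AgentState:
    # Stand-in for code_check: exec() in-process is NOT a sandbox; the
    # tutorial runs the script in a separate container instead.
    source = state.generation.imports + "\n" + state.generation.code
    buffer = io.StringIO()
    try:
        with contextlib.redirect_stdout(buffer):
            exec(source, {})
        error, output = "no", buffer.getvalue()
    except Exception as exc:
        error, output = "yes", repr(exc)
        # Feed the failure back into the history, as code_check does.
        state.messages.append(
            {"role": "user", "content": f"Your solution failed the code execution test: {output}"}
        )
    return AgentState(
        messages=state.messages,
        generation=state.generation,
        iterations=state.iterations,
        error=error,
        output=output,
    )


def run_loop(max_iterations: int = 3) -> AgentState:
    state = AgentState()
    while True:
        state = check(fake_generate(state))
        # Same stopping rule as main(): success, or the iteration budget is spent.
        if state.error == "no" or state.iterations >= max_iterations:
            return state
        # main() calls reflect() here; this sketch skips that step.
```

Calling `run_loop()` stops after the second pass. The first attempt raises a `NameError`, and that error is recorded in `messages`. The stand-in generator reads the recorded error on its next call.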

## Retrieve docs

We define a task that recursively crawls pages starting from a given URL, extracts their text with BeautifulSoup, and concatenates
the pages (sorted by source URL, in reverse order) into a single string. This string is then injected into the LLM prompt as context.

We set `max_depth=20` to cap how deep the crawler follows links and avoid loading an excessive number of documents.
However, even with this limit, the resulting context can still be quite large.
To handle this, we use a model with a large context window: GPT-4o by default, or GPT-4o-mini when `debug` is enabled.
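
The following sketch shows how a depth cap bounds a crawl. It is a breadth-first search over a hypothetical in-memory link graph. It is not `RecursiveUrlLoader`'s implementation, and that loader's exact depth-counting convention may differ.

```python
from collections import deque


def crawl(start: str, links: dict[str, list[str]], max_depth: int) -> list[str]:
    """Breadth-first crawl over an in-memory link graph, capped at max_depth."""
    seen = {start}
    visited = []
    queue = deque([(start, 0)])  # (url, depth); the start page is depth 0
    while queue:
        url, depth = queue.popleft()
        visited.append(url)  # "load" the page
        if depth >= max_depth:
            continue  # keep this page but do not follow its links
        for target in links.get(url, []):
            if target not in seen:  # skip pages already queued (handles cycles)
                seen.add(target)
                queue.append((target, depth + 1))
    return visited


# Hypothetical site: every page links one level deeper.
site = {"/": ["/a"], "/a": ["/a/b"], "/a/b": ["/a/b/c"]}
print(crawl("/", site, max_depth=1))   # ['/', '/a']
print(crawl("/", site, max_depth=20))  # ['/', '/a', '/a/b', '/a/b/c']
```

On a real documentation site, the depth cap limits how far link-following can fan out. A large cap such as 20 still allows a deep, wide site to produce many pages.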

> **📝 Note**
>
> Appending all documents into a single string can result in extremely large contexts, potentially exceeding the LLM’s token limit.
> If your dataset grows beyond what a single prompt can handle, there are a couple of strategies you can use.
> One option is to apply Retrieval-Augmented Generation (RAG), where you chunk the documents, embed them using a model,
> store the vectors in a vector database, and retrieve only the most relevant pieces at inference time.
>
> An alternative approach is to pass references to full files into the prompt, allowing the LLM to decide which files are most relevant based
> on natural-language search over file paths, summaries, or even contents. This method assumes that only a subset of files
> will be necessary for a given task, and the LLM is responsible for navigating the structure and identifying what to read.
> While this can be a lighter-weight solution for smaller datasets, its effectiveness depends on how well the LLM can
> reason over file references and the reliability of its internal search heuristics.
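
The following minimal sketch shows the retrieval idea. It splits the concatenated content back into pages on the separator used by `docs_retriever`. It then scores each page against the question and keeps only the top `k`. Plain word-overlap scoring stands in for embedding similarity here, and no vector database is involved. A real RAG setup would chunk the pages and embed them with a model.

```python
import re
from collections import Counter

SEPARATOR = "\n\n\n --- \n\n\n"  # separator used by docs_retriever


def tokenize(text: str) -> list[str]:
    return re.findall(r"[a-z0-9]+", text.lower())


def overlap_score(question: str, page: str) -> int:
    # Count shared words (with multiplicity); a crude stand-in for embedding similarity.
    q, p = Counter(tokenize(question)), Counter(tokenize(page))
    return sum(min(count, p[word]) for word, count in q.items())


def select_context(question: str, concatenated: str, k: int = 2) -> str:
    # Keep only the k highest-scoring pages instead of the whole corpus.
    pages = concatenated.split(SEPARATOR)
    ranked = sorted(pages, key=lambda page: overlap_score(question, page), reverse=True)
    return SEPARATOR.join(ranked[:k])
```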

```
# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort the list based on the URLs and get the text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))

    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content

# {{/docs-fragment docs_retriever}}

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
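
The post-processing in `docs_retriever` can be exercised without network access. The sketch below assumes a hypothetical `Doc` class standing in for a LangChain document (only `page_content` and `metadata["source"]` are used). It also adds a rough size check based on the common rule of thumb of about four characters per token for English text. That check is only an estimate; the model's tokenizer gives the exact count.

```python
from dataclasses import dataclass, field

SEPARATOR = "\n\n\n --- \n\n\n"


@dataclass
class Doc:
    # Hypothetical stand-in for a LangChain Document.
    page_content: str
    metadata: dict = field(default_factory=dict)


def concatenate(docs: list[Doc]) -> str:
    # Same steps as docs_retriever: sort by source URL, reverse, then join.
    ordered = list(reversed(sorted(docs, key=lambda d: d.metadata["source"])))
    return SEPARATOR.join(d.page_content for d in ordered)


def rough_token_count(text: str) -> int:
    # Rule of thumb only (~4 characters per token for English text).
    return len(text) // 4
```

Comparing `rough_token_count(context)` against the model's context window gives an early warning before a prompt is sent.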

## Code generation

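Next, we define a utility function that builds the LLM chain for generating Python code from user input. The chain pipes
a LangChain `ChatPromptTemplate` (a system prompt containing the retrieved docs, followed by the conversation messages) into an OpenAI
chat model wrapped with `with_structured_output(Code)`, so each response comes back as a `Code` object with `prefix`, `imports`, and `code` fields.
The prompt asks for Flyte 2 scripts that run as-is; whether they actually do is checked later by executing them.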
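
The following plain-Python sketch shows roughly what the prompt template produces. The system message has `{context}` filled with the docs, and the `("placeholder", "{messages}")` entry expands to the conversation history. This is an approximation for illustration, not LangChain's implementation. The template text is abbreviated, and the real chain also passes the result to `llm.with_structured_output(Code)`.

```python
SYSTEM_TEMPLATE = (
    "You are a coding assistant with expertise in Python.\n"
    "Here is a full set of documentation:\n"
    "-------\n{context}\n-------\n"
    "Answer the user question based on the above provided documentation.\n"
    "Here is the user question:"
)


def build_messages(context: str, history: list[dict], question: str) -> list[dict]:
    # The system message gets the docs substituted for {context}.
    system = {"role": "system", "content": SYSTEM_TEMPLATE.format(context=context)}
    # Like generate(): with no history yet, the bare question is the only turn;
    # otherwise the placeholder expands to the accumulated history.
    turns = history if history else [{"role": "user", "content": question}]
    return [system, *turns]
```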

```
# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Code-generation prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.

Use the following pattern to execute the code:

<code>
if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))
</code>

Your response will be shown to the user.
Here is a full set of documentation:

-------
{context}
-------

Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)

    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain

# {{/docs-fragment generate_code_gen_chain}}

# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort the list based on the URLs and get the text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))

    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content

# {{/docs-fragment docs_retriever}}

# {{docs-fragment generate}}
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution

    Args:
        question (str): The user question
        state (dict): The current graph state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, generation
    """

    print("---GENERATING CODE SOLUTION---")

    messages = state.messages
    iterations = state.iterations
    error = state.error

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution
    code_solution = code_gen_chain.invoke(
        {
            "context": concatenated_content,
            "messages": (
                messages if messages else [{"role": "user", "content": question}]
            ),
        }
    )

    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]

    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )

# {{/docs-fragment generate}}

# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code

    Args:
        state (AgentState): The current agent state

    Returns:
        AgentState: Updated state with `error` set to "yes" or "no" and the run output in `output`
    """

    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for code body
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )

    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )

    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )

# {{/docs-fragment code_check}}

# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors

    Args:
        state (AgentState): The current agent state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode

    Returns:
        AgentState: Updated state with the reflection appended to `messages`
    """

    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )

    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )

# {{/docs-fragment reflect}}

# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)

    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation

                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code

                code_output = state.output

                return f"""{prefix}

{imports}
{code}

Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
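
The `main` task in the listing above alternates `generate` and `code_check`. It calls `reflect` only after a failed check and stops once a check passes or `max_iterations` is reached. The following minimal local sketch shows that control flow, with plain Python callables standing in for the Flyte tasks. The names `run_agent_loop`, `LoopState`, and the `fake_*` stubs are hypothetical and are not part of the tutorial code:

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class LoopState:
    # Hypothetical stand-in for the AgentState fields the loop reads
    iterations: int = 0
    error: str = "no"
    history: list[str] = field(default_factory=list)


Step = Callable[[LoopState], LoopState]


def run_agent_loop(generate: Step, check: Step, reflect: Step, max_iterations: int = 3) -> LoopState:
    state = LoopState()
    while True:
        state = generate(state)  # expected to increment state.iterations
        state = check(state)     # expected to set state.error to "yes" or "no"
        if state.error == "no" or state.iterations >= max_iterations:
            return state  # success, or out of attempts
        state = reflect(state)   # only reached after a failed check


# Stub steps: the check fails on the first attempt and passes on the second
def fake_generate(s: LoopState) -> LoopState:
    s.iterations += 1
    s.history.append(f"generate#{s.iterations}")
    return s


def fake_check(s: LoopState) -> LoopState:
    s.error = "no" if s.iterations >= 2 else "yes"
    s.history.append(f"check:{s.error}")
    return s


def fake_reflect(s: LoopState) -> LoopState:
    s.history.append("reflect")
    return s


final = run_agent_loop(fake_generate, fake_check, fake_reflect)
print(final.history)
# ['generate#1', 'check:yes', 'reflect', 'generate#2', 'check:no']
```

`generate` increments the counter before the check runs. As a result, `max_iterations=3` allows at most three generations and two reflections before the loop returns the last attempt.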

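We then define a `generate` task responsible for producing the code solution.
The model returns its answer as a structured `Code` object with three parts, which keeps the answer readable and lets each part be tested on its own:
a short description of the problem and approach (`prefix`), the import statements (`imports`),
and the main body of executable code (`code`).
If the previous attempt failed, the task first appends a "try again" instruction to the message history before calling the model.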

```
# {{docs-fragment generate}}
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution

    Args:
        question (str): The user question
        state (AgentState): The current agent state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode

    Returns:
        AgentState: Updated state containing the new generation
    """

    print("---GENERATING CODE SOLUTION---")

    messages = state.messages
    iterations = state.iterations
    error = state.error

    # On the first pass, start the history with the user's question so it
    # is still part of the conversation when the chain is invoked on retries
    if not messages:
        messages = [{"role": "user", "content": question}]

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )

    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]

    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )

# {{/docs-fragment generate}}

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
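
The chain uses `with_structured_output(Code)`, so each answer arrives as three separate strings. `generate` flattens them into an assistant message, and `code_check` later recombines them into two scripts. The sketch below reproduces that string handling with a plain dataclass so it runs without LangChain or an API key. `CodeAnswer`, `to_assistant_message`, and `build_check_scripts` are hypothetical stand-ins, not tutorial code:

```python
from dataclasses import dataclass


@dataclass
class CodeAnswer:
    # Hypothetical dataclass with the same three fields as the tutorial's `Code` model
    prefix: str = ""
    imports: str = ""
    code: str = ""


def to_assistant_message(answer: CodeAnswer) -> dict[str, str]:
    # Same content layout that `generate` appends to the message history
    return {
        "role": "assistant",
        "content": f"{answer.prefix} \n Imports: {answer.imports} \n Code: {answer.code}",
    }


def build_check_scripts(answer: CodeAnswer) -> tuple[str, str]:
    # `code_check` writes an imports-only script and an imports-plus-body script
    imports = answer.imports.strip()
    code = answer.code.strip()
    return imports + "\n", imports + "\n" + code + "\n"


answer = CodeAnswer(
    prefix="Print the square root of 16.",
    imports="import math",
    code="print(math.sqrt(16))",
)
imports_script, full_script = build_check_scripts(answer)
print(repr(imports_script))  # 'import math\n'
print(repr(full_script))     # 'import math\nprint(math.sqrt(16))\n'
print(to_assistant_message(answer)["content"])
```

Keeping the imports in their own field lets the first check fail fast on a missing dependency without executing any of the solution's own logic.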

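A `ContainerTask` then executes this code in an isolated container.
It receives the script as a `File` input, runs it with `uv run --script`, redirects stdout and stderr to `/var/outputs/result`, and writes the exit status to `/var/outputs/exit_code`.
Flyte returns the contents of these files as the task's two string outputs, `result` and `exit_code`. Each run is limited to 1 CPU and 1 GiB of memory.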

```
# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)

# {{/docs-fragment code_runner_task}}

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
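
To experiment with the runner's input/output contract without a cluster, the sketch below approximates the container command locally. It writes the script to a temporary file and runs it with the current Python interpreter instead of `uv run --script`. It merges stderr into stdout the way `> /var/outputs/result 2>&1` does, and it returns the exit status as text, as `echo $?` does. `run_script` is a hypothetical helper. It provides no container isolation and no resource limits, and it does not install inline script dependencies, so use it only with trusted code:

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def run_script(source: str, timeout: float = 60.0) -> tuple[str, str]:
    """Run `source` in a fresh Python process; return (combined output, exit code as text)."""
    with tempfile.TemporaryDirectory() as workdir:
        script = Path(workdir) / "script.py"
        script.write_text(source)
        proc = subprocess.run(
            [sys.executable, str(script)],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into stdout, like `> result 2>&1`
            text=True,
            timeout=timeout,
        )
    # `echo $?` writes the status followed by a newline, which is why callers strip it
    return proc.stdout, f"{proc.returncode}\n"


out, code = run_script("print('hello')")
print(repr(out), repr(code))  # 'hello\n' '0\n'

out, code = run_script("import module_that_does_not_exist")
print(code.strip() != "0")           # True: the import failed
print("ModuleNotFoundError" in out)  # True: the traceback was captured with stdout
```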

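The `code_check` task verifies that the generated code actually runs.
It first runs only the import statements through `code_runner_task`, then runs the imports together with the code body.
If either run exits with a nonzero status, it appends the captured output to the message history as a user message and sets `error` to `"yes"`; otherwise it sets `error` to `"no"`. The output of the last run is stored in the agent state's `output` field.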
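
The decision logic of `code_check` can be exercised without Flyte or a container by passing the runner in as a function. In this sketch, `check_solution`, `Runner`, and `fake_runner` are hypothetical names. The feedback strings are copied from the tutorial code, and the fake runner stands in for `code_runner_task`:

```python
from typing import Callable

# A runner takes script source and returns (combined output, exit code as text),
# matching the two string outputs of `code_runner_task`
Runner = Callable[[str], tuple[str, str]]


def check_solution(imports: str, code: str, run: Runner) -> tuple[str, str, str]:
    """Return (error flag, output, feedback) using the same order of checks as `code_check`."""
    imports, code = imports.strip(), code.strip()

    # Stage 1: run the import statements on their own
    output, exit_code = run(imports + "\n")
    if exit_code.strip() != "0":
        return "yes", output, f"Your solution failed the import test: {output}"

    # Stage 2: run the imports together with the code body
    output, exit_code = run(imports + "\n" + code + "\n")
    if exit_code.strip() != "0":
        return "yes", output, f"Your solution failed the code execution test: {output}"

    return "no", output, ""


# Fake runner that records every script it is asked to run
calls: list[str] = []


def fake_runner(script: str) -> tuple[str, str]:
    calls.append(script)
    if "missing_pkg" in script:
        return "ModuleNotFoundError: No module named 'missing_pkg'\n", "1\n"
    return "ok\n", "0\n"


print(check_solution("import missing_pkg", "print(1)", fake_runner)[0])  # yes
print(len(calls))  # 1: the code body never ran
print(check_solution("import math", "print(math.pi)", fake_runner))  # ('no', 'ok\n', '')
```

Running the imports separately means that a missing or misspelled package produces a short, targeted error message before any of the solution's own logic executes.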

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "langchain-core==0.3.66",
#    "langchain-openai==0.3.24",
#    "langchain-community==0.3.26",
#    "beautifulsoup4==4.13.4",
#    "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)

# {{/docs-fragment code_runner_task}}

# {{docs-fragment env}}
import tempfile
from typing import Optional

from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field

container_env = flyte.TaskEnvironment.from_task(
    "code-runner-container", code_runner_task
)

env = flyte.TaskEnvironment(
    name="code_runner",
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
    resources=flyte.Resources(cpu=1),
    depends_on=[container_env],
)

# {{/docs-fragment env}}

# {{docs-fragment code_base_model}}
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )

# {{/docs-fragment code_base_model}}

# {{docs-fragment agent_state}}
class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None

# {{/docs-fragment agent_state}}

# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Grader prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.

Use the following pattern to execute the code:

<code>
if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))
</code>

Your response will be shown to the user.
Here is a full set of documentation:

-------
{context}
-------

Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)

    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain

# {{/docs-fragment generate_code_gen_chain}}

# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort the pages by source URL, then reverse the order (descending) before joining their text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))

    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content

# {{/docs-fragment docs_retriever}}

# {{docs-fragment generate}}
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution.

    Args:
        question (str): The user question
        state (AgentState): The current agent state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode (uses a smaller model)

    Returns:
        AgentState: Updated state with the new generation and an incremented iteration count
    """

    print("---GENERATING CODE SOLUTION---")

    # Seed the history with the user question on the first pass so that the
    # original request stays in context on later retries.
    messages = state.messages or [{"role": "user", "content": question}]
    iterations = state.iterations
    error = state.error

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution (async call, so the task's event loop is not blocked)
    code_solution = await code_gen_chain.ainvoke(
        {"context": concatenated_content, "messages": messages}
    )

    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]

    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )

# {{/docs-fragment generate}}

# {{docs-fragment code_check}}
@env.task
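async def code_check(state: AgentState) -> AgentState:
    """
    Check the generated code by running the imports alone, then the full script.

    Args:
        state (AgentState): The current agent state

    Returns:
        AgentState: Updated state with `error` set to "yes" or "no" and the run output
    """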

    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for the full script (imports + code body)
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )

    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )

    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )

# {{/docs-fragment code_check}}

# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors.

    Args:
        state (AgentState): The current agent state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode (uses a smaller model)

    Returns:
        AgentState: Updated state with the reflection appended to `messages`
    """

    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection (async call, so the task's event loop is not blocked)
    reflections = await code_gen_chain.ainvoke(
        {"context": concatenated_content, "messages": messages}
    )

    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )

# {{/docs-fragment reflect}}

# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)

    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation

                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code

                code_output = state.output

                return f"""{prefix}

{imports}
{code}

Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
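
The `code_runner_task` defined at the top of the script follows a simple file contract: the script arrives under `/var/inputs`, and the container writes its combined stdout/stderr to `/var/outputs/result` and its exit status to `/var/outputs/exit_code`, which `code_check` then compares against `"0"`. To debug the two checks without a cluster, a local stand-in like the following may help. It is a sketch under stated assumptions, not the ContainerTask itself: `run_script_like_container` is a hypothetical helper, temporary directories replace `/var/inputs` and `/var/outputs`, and the script runs with the current Python interpreter instead of `uv run --script`, so inline script dependencies are not installed.

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def run_script_like_container(source: str) -> tuple[str, str]:
    """Hypothetical local stand-in for `code_runner_task`.

    Mirrors its I/O contract: the script is written to an inputs directory,
    combined stdout/stderr goes to `result`, and the exit status goes to
    `exit_code`. Both are returned as strings, like the task's outputs.
    """
    with tempfile.TemporaryDirectory() as inputs, tempfile.TemporaryDirectory() as outputs:
        script = Path(inputs) / "script"
        script.write_text(source)
        result_path = Path(outputs) / "result"
        exit_code_path = Path(outputs) / "exit_code"
        with result_path.open("w") as result_file:
            # Equivalent of `... > result 2>&1`: merge stderr into stdout.
            proc = subprocess.run(
                [sys.executable, str(script)],
                stdout=result_file,
                stderr=subprocess.STDOUT,
            )
        # Equivalent of `echo $? > exit_code` (echo adds a trailing newline).
        exit_code_path.write_text(f"{proc.returncode}\n")
        # Read both files before the temporary directories are removed.
        return result_path.read_text(), exit_code_path.read_text()


if __name__ == "__main__":
    # Same split as `code_check`: first the imports alone, then imports + code.
    imports = "import math"
    code = "print(math.sqrt(16))"
    for label, source in [("imports", imports), ("code", imports + "\n" + code)]:
        output, exit_code = run_script_like_container(source + "\n")
        status = "PASSED" if exit_code.strip() == "0" else "FAILED"
        print(f"{label}: {status} -> {output.strip()!r}")
```

As in `code_check`, the exit code comes back as a string with a trailing newline, which is why the comparison strips it first.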

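If the generated code fails either check, the `reflect` task asks the model to analyze the failure, using the error output that `code_check` added to the message history.
The reflection is appended to the agent state's messages so that the next generation attempt can take it into account.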

```
# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors.

    Args:
        state (AgentState): The current agent state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode (uses a smaller model)

    Returns:
        AgentState: Updated state with the reflection appended to `messages`
    """

    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection (async call, so the task's event loop is not blocked)
    reflections = await code_gen_chain.ainvoke(
        {"context": concatenated_content, "messages": messages}
    )

    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )

# {{/docs-fragment reflect}}

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
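
To see what the model has to work with on a retry, it can help to trace the messages that one failed pass adds to the history. The sketch below is a simplified, functional model of those appends, assuming the execution check is the one that failed (the import check uses a slightly different message). The helper `record_failed_pass` and the `RETRY_PROMPT` constant are hypothetical; the message texts are copied from `generate`, `code_check`, and `reflect`.

```python
RETRY_PROMPT = (
    "Now, try again. Invoke the code tool to structure the output "
    "with a prefix, imports, and code block:"
)


def record_failed_pass(
    messages: list[dict[str, str]], attempt: str, error_output: str, reflection: str
) -> list[dict[str, str]]:
    """Return a new history with the messages one failed pass appends (simplified)."""
    return messages + [
        # generate: the model's attempt, stored as an assistant message
        {"role": "assistant", "content": attempt},
        # code_check: the failing output, fed back as a user message
        {
            "role": "user",
            "content": f"Your solution failed the code execution test: {error_output}",
        },
        # reflect: the model's analysis of the failure
        {"role": "assistant", "content": f"Here are reflections on the error: {reflection}"},
        # next generate call: the retry instruction, added because error == "yes"
        {"role": "user", "content": RETRY_PROMPT},
    ]


if __name__ == "__main__":
    history = record_failed_pass(
        [],
        attempt="import flyte ...",
        error_output="NameError: name 'env' is not defined",
        reflection="The task environment was never defined.",
    )
    for message in history:
        print(f"{message['role']:>9}: {message['content'][:60]}")
```

Returning a new list rather than mutating the input with `+=` is a design choice of the sketch that keeps it easy to test.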

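Finally, we define a `main` task that retrieves the documentation once and then orchestrates the steps above.
Each pass runs `generate` and `code_check` inside a `flyte.group`. If the code fails, `main` calls `reflect` and starts another pass, stopping as soon as the code passes or the iteration count reaches `max_iterations`; in either case it returns the last solution together with its execution output.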

```
# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)

    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation

                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code

                code_output = state.output

                return f"""{prefix}

{imports}
{code}

Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()

# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py*
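
Stripped of Flyte and the LLM, the control flow of `main` is a bounded generate, check, reflect loop. The following synchronous sketch mirrors it with plain callables so the stopping rule can be tested in isolation. The function `run_agent_loop` and the stub callables are hypothetical; in the tutorial these steps are the `generate`, `code_check`, and `reflect` tasks, and the state is an `AgentState` rather than a list of strings.

```python
from typing import Callable


def run_agent_loop(
    generate: Callable[[list[str]], str],
    check: Callable[[str], tuple[bool, str]],
    reflect: Callable[[list[str]], str],
    max_iterations: int = 3,
) -> tuple[str, str, int]:
    """Simplified, synchronous version of the loop in `main`.

    Returns the last candidate, its check output, and the number of generate calls.
    """
    history: list[str] = []
    iterations = 0
    while True:
        candidate = generate(history)
        iterations += 1  # only the generate step advances the counter
        passed, output = check(candidate)
        if passed or iterations >= max_iterations:
            # Like `main`, return the last attempt even if it still fails
            return candidate, output, iterations
        history.append(f"failed: {output}")
        history.append(f"reflection: {reflect(history)}")


if __name__ == "__main__":
    attempts = iter(["print(undefined_name)", "print('ok')"])

    def fake_generate(history: list[str]) -> str:
        return next(attempts)

    def fake_check(candidate: str) -> tuple[bool, str]:
        passed = "undefined_name" not in candidate
        return passed, "ok" if passed else "NameError"

    def fake_reflect(history: list[str]) -> str:
        return "define the name before using it"

    print(run_agent_loop(fake_generate, fake_check, fake_reflect))
```

As in `main`, the loop returns the last attempt even when it never passes, and `reflect` is not called after the final failed pass.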

## Running the code agent

The agent reads the OpenAI API key from the `openai_api_key` secret declared in its task environment, so make sure that secret exists before you start. You can then run the code agent on a Flyte/Union cluster using the following command:

```
uv run agent.py
```

The script prints the run URL and waits for the run to finish. If things are working properly, you should see output similar to the following:

```
---GENERATING CODE SOLUTION---
---CHECKING CODE---
---CODE BLOCK CHECK: PASSED---
---NO CODE TEST FAILURES---
---DECISION: FINISH---
In this solution, we define two tasks using Flyte v2.
The first task, `oomer`, is designed to simulate an out-of-memory (OOM) error by attempting to allocate a large list.
The second task, `failure_recovery`, attempts to execute `oomer` and catches any OOM errors.
If an OOM error is caught, it retries the `oomer` task with increased memory resources.
This pattern demonstrates how to handle resource-related exceptions and dynamically adjust task configurations in Flyte workflows.

import asyncio
import flyte
import flyte.errors
env = flyte.TaskEnvironment(name="oom_example", resources=flyte.Resources(cpu=1, memory="250Mi"))

@env.task
async def oomer(x: int):
    large_list = [0] * 100000000  # Simulate OOM
    print(len(large_list))

@env.task
async def always_succeeds() -> int:
    await asyncio.sleep(1)
    return 42

...
```

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/text_to_sql ===

# Text-to-SQL

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/text_to_sql); based on work by [LlamaIndex](https://docs.llamaindex.ai/en/stable/examples/workflow/advanced_text_to_sql/).

Data analytics drives modern decision-making, but SQL often creates a bottleneck. Writing queries requires technical expertise, so non-technical stakeholders must rely on data teams. That translation layer slows everyone down.

Text-to-SQL narrows this gap by turning natural language into executable SQL queries. It lowers the barrier to structured data and makes databases accessible to more people.

In this tutorial, we build a Text-to-SQL workflow using LlamaIndex and evaluate it on the [WikiTableQuestions dataset](https://ppasupat.github.io/WikiTableQuestions/) (a benchmark of natural language questions over semi-structured tables). We then explore prompt optimization to see whether it improves accuracy and show how to track prompts and results over time. Along the way, we'll see what worked, what didn't, and what we learned about building durable evaluation pipelines. The pattern here can be adapted to your own datasets and workflows.

![Evaluation](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/text-to-sql/evaluation.png)

## Ingesting data

We start by ingesting the WikiTableQuestions dataset, which comes as CSV files, into a SQLite database. This database serves as the source of truth for our Text-to-SQL pipeline.

```
import asyncio
import fnmatch
import os
import re
import zipfile

import flyte
import pandas as pd
import requests
from flyte.io import Dir, File
from llama_index.core.llms import ChatMessage
from llama_index.core.prompts import ChatPromptTemplate
from llama_index.llms.openai import OpenAI
from pydantic import BaseModel, Field
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from utils import env

# {{docs-fragment table_info}}
class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(..., description="table name (underscores only, no spaces)")
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )

# {{/docs-fragment table_info}}

@env.task
async def download_and_extract(zip_path: str, search_glob: str) -> Dir:
    """Download and extract the dataset zip file if not already available."""
    output_zip = "data.zip"
    extract_dir = "wiki_table_questions"

    if not os.path.exists(zip_path):
        response = requests.get(zip_path, stream=True)
        response.raise_for_status()  # fail fast instead of saving an HTTP error page as data.zip
        with open(output_zip, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    else:
        output_zip = zip_path
        print(f"Using existing file {output_zip}")

    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(output_zip, "r") as zip_ref:
        for member in zip_ref.namelist():
            if fnmatch.fnmatch(member, search_glob):
                zip_ref.extract(member, extract_dir)

    remote_dir = await Dir.from_local(extract_dir)
    return remote_dir

async def read_csv_file(
    csv_file: File, nrows: int | None = None
) -> pd.DataFrame | None:
    """Safely download and parse a CSV file into a DataFrame."""
    try:
        local_csv_file = await csv_file.download()
        return pd.read_csv(local_csv_file, nrows=nrows)
    except Exception as e:
        print(f"Error parsing {csv_file.path}: {e}")
        return None

def sanitize_column_name(col_name: str) -> str:
    """Sanitize column names by replacing spaces/special chars with underscores."""
    return re.sub(r"\W+", "_", col_name)

async def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    """Create a SQL table from a Pandas DataFrame."""
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Define table columns based on DataFrame dtypes
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    table = Table(table_name, metadata_obj, *columns)

    # Create table in database
    metadata_obj.create_all(engine)

    # Insert data into table
    with engine.begin() as conn:
        for _, row in df.iterrows():
            conn.execute(table.insert().values(**row.to_dict()))

@flyte.trace
async def create_table(
    csv_file: File, table_info: TableInfo, database_path: str
) -> str:
    """Safely create a table from CSV if parsing succeeds."""
    df = await read_csv_file(csv_file)
    if df is None:
        return "false"

    print(f"Creating table: {table_info.table_name}")

    engine = create_engine(f"sqlite:///{database_path}")
    metadata_obj = MetaData()

    await create_table_from_dataframe(df, table_info.table_name, engine, metadata_obj)
    return "true"

@flyte.trace
async def llm_structured_predict(
    df_str: str,
    table_names: list[str],
    prompt_tmpl: ChatPromptTemplate,
    feedback: str,
    llm: OpenAI,
) -> TableInfo:
    # Async variant: awaiting it lets the semaphore-bounded LLM calls actually overlap
    return await llm.astructured_predict(
        TableInfo,
        prompt_tmpl,
        feedback=feedback,
        table_str=df_str,
        exclude_table_name_list=str(list(table_names)),
    )

async def generate_unique_table_info(
    df_str: str,
    table_names: list[str],
    prompt_tmpl: ChatPromptTemplate,
    llm: OpenAI,
    tablename_lock: asyncio.Lock,
    retries: int = 3,
) -> TableInfo | None:
    """Ask the LLM for a TableInfo, retrying with feedback until the table name is unique."""
    last_table_name = None
    for attempt in range(retries):
        feedback = ""
        if attempt > 0:
            feedback = f"Note: '{last_table_name}' already exists. Please pick a new name not in {table_names}."

        table_info = await llm_structured_predict(
            df_str, table_names, prompt_tmpl, feedback, llm
        )
        last_table_name = table_info.table_name

        async with tablename_lock:
            if table_info.table_name not in table_names:
                table_names.append(table_info.table_name)
                return table_info

        print(f"Table name {table_info.table_name} already exists, retrying...")

    return None

async def process_csv_file(
    csv_file: File,
    table_names: list[str],
    semaphore: asyncio.Semaphore,
    tablename_lock: asyncio.Lock,
    llm: OpenAI,
    prompt_tmpl: ChatPromptTemplate,
) -> TableInfo | None:
    """Process a single CSV file to generate a unique TableInfo."""
    async with semaphore:
        df = await read_csv_file(csv_file, nrows=10)
        if df is None:
            return None
        return await generate_unique_table_info(
            df.to_csv(), table_names, prompt_tmpl, llm, tablename_lock
        )

@env.task
async def extract_table_info(
    data_dir: Dir, model: str, concurrency: int
) -> list[TableInfo | None]:
    """Extract structured table information from CSV files."""
    table_names: list[str] = []
    semaphore = asyncio.Semaphore(concurrency)
    tablename_lock = asyncio.Lock()
    llm = OpenAI(model=model)

    prompt_str = """\
    Provide a JSON object with the following fields:

    - `table_name`: must be unique and descriptive (underscores only, no generic names).
    - `table_summary`: short and concise summary of the table.

    Do NOT use any of these table names: {exclude_table_name_list}

    Table:
    {table_str}

    {feedback}
    """
    prompt_tmpl = ChatPromptTemplate(
        message_templates=[ChatMessage.from_str(prompt_str, role="user")]
    )

    tasks = [
        process_csv_file(
            csv_file, table_names, semaphore, tablename_lock, llm, prompt_tmpl
        )
        async for csv_file in data_dir.walk()
    ]

    return await asyncio.gather(*tasks)

# {{docs-fragment data_ingestion}}
@env.task
async def data_ingestion(
    csv_zip_path: str = "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
    search_glob: str = "WikiTableQuestions/csv/200-csv/*.csv",
    concurrency: int = 5,
    model: str = "gpt-4o-mini",
) -> tuple[File, list[TableInfo | None]]:
    """Main data ingestion pipeline: download → extract → analyze → create DB."""
    data_dir = await download_and_extract(csv_zip_path, search_glob)
    table_infos = await extract_table_info(data_dir, model, concurrency)

    database_path = "wiki_table_questions.db"

    i = 0
    async for csv_file in data_dir.walk():
        table_info = table_infos[i]
        if table_info:
            ok = await create_table(csv_file, table_info, database_path)
            if ok == "false":
                table_infos[i] = None
        else:
            print(f"Skipping table creation for {csv_file} due to missing TableInfo.")
        i += 1

    db_file = await File.from_local(database_path)
    return db_file, table_infos

# {{/docs-fragment data_ingestion}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/data_ingestion.py*
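
`extract_table_info` processes CSV files concurrently, and every coroutine shares one `table_names` list. `generate_unique_table_info` wraps the check-and-append on that list in an `asyncio.Lock`. When a proposed name is already taken, it retries with feedback.

The sketch below isolates that pattern. It is an illustration, not the tutorial's code: `propose_name` is a hypothetical coroutine that stands in for the LLM call.

```python
from __future__ import annotations

import asyncio


async def propose_name(summary: str, attempt: int) -> str:
    """Hypothetical stand-in for the LLM: a generic name first, then a specific one."""
    await asyncio.sleep(0)  # yield to the event loop, as a real network call would
    return "table" if attempt == 0 else f"{summary}_{attempt}"


async def reserve_unique_name(
    summary: str, taken: list[str], lock: asyncio.Lock, retries: int = 3
) -> str | None:
    for attempt in range(retries):
        name = await propose_name(summary, attempt)
        # The lock marks check-and-append as one critical section; it keeps this
        # correct even if an `await` is later added between the check and the append.
        async with lock:
            if name not in taken:
                taken.append(name)
                return name
    return None  # out of retries; the tutorial skips tables without a TableInfo


async def main() -> list[str | None]:
    taken: list[str] = []
    lock = asyncio.Lock()
    summaries = ["players", "albums", "cities"]
    return await asyncio.gather(*(reserve_unique_name(s, taken, lock) for s in summaries))


print(asyncio.run(main()))
```

All three coroutines first propose the generic name `table`. Exactly one claims it, and the others find it taken and retry with a more specific name.

In a single-threaded event loop, a block with no `await` inside already runs without interruption. The lock therefore matters mainly as a safeguard: it keeps the pattern correct if the critical section ever grows an `await`.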

The ingestion step:

1. Downloads the dataset (a zip archive from GitHub).
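2. Extracts the CSV files that match `search_glob` and uploads them as a Flyte `Dir`.
3. Uses an LLM (`gpt-4o-mini` by default) to generate a unique table name and a short summary for each CSV file.
4. Creates corresponding tables in SQLite.

The `data_ingestion` task returns both the SQLite database (as a Flyte `File`) and the generated table metadata.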
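
Step 4 comes down to three operations: read a CSV, sanitize its header, and bulk-insert the rows. The sketch below does the same using only the standard library (`csv` and `sqlite3`) instead of pandas and SQLAlchemy.

It reuses the tutorial's `sanitize_column_name` rule. Unlike `create_table_from_dataframe`, it declares no column types and relies on SQLite's dynamic typing. `load_csv_into_sqlite` and the sample data are made up for illustration.

```python
import csv
import io
import re
import sqlite3


def sanitize_column_name(col_name: str) -> str:
    # Same rule as the tutorial: replace each run of non-word characters with "_"
    return re.sub(r"\W+", "_", col_name)


def load_csv_into_sqlite(csv_text: str, table_name: str, conn: sqlite3.Connection) -> int:
    """Create `table_name` from CSV text and insert every row; returns the row count."""
    reader = csv.reader(io.StringIO(csv_text))
    header = [sanitize_column_name(col) for col in next(reader)]
    rows = list(reader)
    columns = ", ".join(f'"{col}"' for col in header)  # quoted identifiers
    placeholders = ", ".join("?" for _ in header)
    conn.execute(f'CREATE TABLE "{table_name}" ({columns})')  # untyped columns
    # One executemany call instead of one INSERT statement per row
    conn.executemany(f'INSERT INTO "{table_name}" VALUES ({placeholders})', rows)
    conn.commit()
    return len(rows)


# Made-up sample data
csv_text = "Year,Team Name,Points (avg.)\n2001,Falcons,21.5\n2002,Hawks,19.0\n"
conn = sqlite3.connect(":memory:")
print(load_csv_into_sqlite(csv_text, "team_points", conn))  # 2
print(conn.execute('SELECT "Team_Name" FROM team_points WHERE "Year" = ?', ("2001",)).fetchall())
# [('Falcons',)]
```

The sanitizer turns `Points (avg.)` into `Points_avg_`, which is why generated SQL must use the sanitized names rather than the original CSV headers. Passing all rows to `executemany` is typically faster than the `iterrows()` loop in `create_table_from_dataframe`, which issues one `INSERT` per row.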

With Union artifacts (coming soon!), you'll be able to persist the ingested SQLite database as an artifact, so downstream pipelines can reuse it instead of rerunning data ingestion each time.

## From question to SQL

Next, we define a workflow that converts natural language into executable SQL using a retrieval-augmented generation (RAG) approach.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "llama-index-core>=0.11.0",
#    "llama-index-llms-openai>=0.2.0",
#    "sqlalchemy>=2.0.0",
#    "pandas>=2.0.0",
#    "requests>=2.25.0",
#    "pydantic>=2.0.0",
# ]
# main = "text_to_sql"
# params = ""
# ///

import asyncio
from pathlib import Path

import flyte
from data_ingestion import TableInfo, data_ingestion
from flyte.io import Dir, File
from llama_index.core import (
    PromptTemplate,
    SQLDatabase,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI
from sqlalchemy import create_engine, text
from utils import env

# {{docs-fragment index_tables}}
@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
    """Index a single table into vector store."""
    path = f"{table_index_dir}/{table_name}"
    engine = create_engine(database_uri)

    def _fetch_rows():
        with engine.connect() as conn:
            cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
            return cursor.fetchall()

    result = await asyncio.to_thread(_fetch_rows)
    nodes = [TextNode(text=str(tuple(row))) for row in result]
    index = VectorStoreIndex(nodes)
    index.set_index_id("vector_index")
    index.storage_context.persist(path)

    return path

@env.task
async def index_all_tables(db_file: File) -> Dir:
    """Index all tables concurrently."""
    table_index_dir = "table_indices"
    Path(table_index_dir).mkdir(exist_ok=True)

    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)

    tasks = [
        index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
        for t in sql_database.get_usable_table_names()
    ]
    await asyncio.gather(*tasks)

    remote_dir = await Dir.from_local(table_index_dir)
    return remote_dir

# {{/docs-fragment index_tables}}

@flyte.trace
async def get_table_schema_context(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
) -> str:
    """Retrieve schema + optional description context for a single table."""
    engine = create_engine(database_uri)
    sql_database = SQLDatabase(engine)

    table_info = sql_database.get_single_table_info(table_schema_obj.table_name)

    if table_schema_obj.context_str:
        table_info += f" The table description is: {table_schema_obj.context_str}"

    return table_info

@flyte.trace
async def get_table_row_context(
    table_schema_obj: SQLTableSchema,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Retrieve row-level context examples using vector search."""
    storage_context = StorageContext.from_defaults(
        persist_dir=str(f"{local_vector_index_dir}/{table_schema_obj.table_name}")
    )
    vector_index = load_index_from_storage(storage_context, index_id="vector_index")
    vector_retriever = vector_index.as_retriever(similarity_top_k=2)
    relevant_nodes = vector_retriever.retrieve(query)

    if not relevant_nodes:
        return ""

    row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
    for node in relevant_nodes:
        row_context += str(node.get_content()) + "\n"

    return row_context

async def process_table(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Combine schema + row context for one table."""
    table_info = await get_table_schema_context(table_schema_obj, database_uri)
    row_context = await get_table_row_context(
        table_schema_obj, local_vector_index_dir, query
    )

    full_context = table_info
    if row_context:
        full_context += "\n" + row_context

    print(f"Table Info: {full_context}")
    return full_context

async def get_table_context_and_rows_str(
    query: str,
    database_uri: str,
    table_schema_objs: list[SQLTableSchema],
    vector_index_dir: Dir,
):
    """Get combined schema + row context for all tables."""
    local_vector_index_dir = await vector_index_dir.download()

    # run per-table work concurrently
    context_strs = await asyncio.gather(
        *[
            process_table(t, database_uri, local_vector_index_dir, query)
            for t in table_schema_objs
        ]
    )

    return "\n\n".join(context_strs)

# {{docs-fragment retrieve_tables}}
@env.task
async def retrieve_tables(
    query: str,
    table_infos: list[TableInfo | None],
    db_file: File,
    vector_index_dir: Dir,
) -> str:
    """Retrieve relevant tables and return schema context string."""
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)

    table_node_mapping = SQLTableNodeMapping(sql_database)
    table_schema_objs = [
        SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
        for t in table_infos
        if t is not None
    ]

    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
    )
    obj_retriever = obj_index.as_retriever(similarity_top_k=3)

    retrieved_schemas = obj_retriever.retrieve(query)
    return await get_table_context_and_rows_str(
        query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
    )

# {{/docs-fragment retrieve_tables}}

def parse_response_to_sql(chat_response: ChatResponse) -> str:
    """Extract SQL query from LLM response."""
    response = chat_response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()

# {{docs-fragment sql_and_response}}
@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
    """Generate SQL query from natural language question and table context."""
    llm = OpenAI(model=model)

    fmt_messages = (
        PromptTemplate(
            prompt,
            prompt_type=PromptType.TEXT_TO_SQL,
        )
        .partial_format(dialect="sqlite")
        .format_messages(query_str=query, schema=table_context)
    )

    chat_response = await llm.achat(fmt_messages)
    return parse_response_to_sql(chat_response)

@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
    """Run SQL query on database and synthesize final response."""
    await db_file.download(local_path="local_db.sqlite")

    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)

    retrieved_rows = sql_retriever.retrieve(sql)

    response_synthesis_prompt = PromptTemplate(
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )

    llm = OpenAI(model=model)
    fmt_messages = response_synthesis_prompt.format_messages(
        sql_query=sql,
        context_str=str(retrieved_rows),
        query_str=query,
    )
    chat_response = await llm.achat(fmt_messages)
    return chat_response.message.content

# {{/docs-fragment sql_and_response}}

# {{docs-fragment text_to_sql}}
@env.task
async def text_to_sql(
    system_prompt: str = (
        "Given an input question, first create a syntactically correct {dialect} "
        "query to run, then look at the results of the query and return the answer. "
        "You can order the results by a relevant column to return the most "
        "interesting examples in the database.\n\n"
        "Never query for all the columns from a specific table, only ask for a "
        "few relevant columns given the question.\n\n"
        "Pay attention to use only the column names that you can see in the schema "
        "description. "
        "Be careful to not query for columns that do not exist. "
        "Pay attention to which column is in which table. "
        "Also, qualify column names with the table name when needed. "
        "You are required to use the following format, each taking one line:\n\n"
        "Question: Question here\n"
        "SQLQuery: SQL Query to run\n"
        "SQLResult: Result of the SQLQuery\n"
        "Answer: Final answer here\n\n"
        "Only use tables listed below.\n"
        "{schema}\n\n"
        "Question: {query_str}\n"
        "SQLQuery: "
    ),
    query: str = "What was the year that The Notorious BIG was signed to Bad Boy?",
    model: str = "gpt-4o-mini",
) -> str:
    db_file, table_infos = await data_ingestion()
    vector_index_dir = await index_all_tables(db_file)
    table_context = await retrieve_tables(query, table_infos, db_file, vector_index_dir)
    sql = await generate_sql(query, table_context, model, system_prompt)
    return await generate_response(query, sql, db_file, model)

# {{/docs-fragment text_to_sql}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(text_to_sql)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*
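
`parse_response_to_sql` is plain string slicing. It:

- keeps the text after `SQLQuery:`, if the model repeated that marker;
- cuts everything from `SQLResult:` onward;
- trims whitespace and backticks.

Because the prompt already ends with `SQLQuery: `, the reply may start directly with the SQL. Below is an equivalent string-only version you can test without an LLM; the name `parse_sql_from_text` is ours.

```python
def parse_sql_from_text(response: str) -> str:
    """String-only equivalent of the tutorial's parse_response_to_sql."""
    marker = "SQLQuery:"
    start = response.find(marker)
    if start != -1:
        response = response[start + len(marker):]  # keep only what follows the marker
    end = response.find("SQLResult:")
    if end != -1:
        response = response[:end]  # drop the model's own SQLResult/Answer lines
    # strip("`") removes any run of backticks at either end, e.g. from a code fence
    return response.strip().strip("`").strip()


full = "Question: How many rows?\nSQLQuery: SELECT COUNT(*) FROM t\nSQLResult: 3\nAnswer: 3"
bare = "SELECT name FROM t LIMIT 1\nSQLResult: alice"
print(parse_sql_from_text(full))  # SELECT COUNT(*) FROM t
print(parse_sql_from_text(bare))  # SELECT name FROM t LIMIT 1
```

One limitation is shared with the original: a fence that carries a language tag (three backticks followed by `sql`) leaves the word `sql` at the start of the result. You may want to strip a leading language tag as well.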

The main `text_to_sql` task orchestrates the pipeline:

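- Ingest data (`data_ingestion`)
- Build a vector index for each table (`index_all_tables`)
- Retrieve relevant tables and rows (`retrieve_tables`)
- Generate a SQL query with an LLM (`generate_sql`)
- Execute the query and synthesize an answer (`generate_response`)

Both LLM calls use `gpt-4o-mini` by default. The prompt asks the model to reply in a fixed `Question:` / `SQLQuery:` / `SQLResult:` / `Answer:` format, so `parse_response_to_sql` can extract just the SQL before it is executed.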
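
Once the SQL is generated, `generate_response` executes it and asks the LLM to phrase an answer from the raw rows. The sketch below reproduces the prompt-building half with `sqlite3` and `str.format`, reusing the template text from `response_synthesis_prompt`.

It skips the LLM call and simply stringifies the fetched rows, which simplifies what `SQLRetriever` actually returns. `build_synthesis_prompt` and the sample table are hypothetical.

```python
import sqlite3

# Template text copied from response_synthesis_prompt in text_to_sql.py
SYNTHESIS_TEMPLATE = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)


def build_synthesis_prompt(conn: sqlite3.Connection, question: str, sql: str) -> str:
    """Run the generated SQL and fill the synthesis template with its rows."""
    rows = conn.execute(sql).fetchall()
    return SYNTHESIS_TEMPLATE.format(query_str=question, sql_query=sql, context_str=str(rows))


# Made-up sample table
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE releases (year INTEGER, label TEXT)")
conn.executemany("INSERT INTO releases VALUES (?, ?)", [(2001, "North"), (2005, "South")])
prompt = build_synthesis_prompt(
    conn, "When did North release?", "SELECT year FROM releases WHERE label = 'North'"
)
print(prompt)
```

Keeping execution and phrasing separate means the SQL that actually ran appears in the prompt. This arguably makes a wrong answer easier to trace back to a bad query.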

### Vector indexing

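We embed each table's rows into a separate vector index, one per table. At query time, `get_table_row_context` retrieves the two rows most similar to the question (`similarity_top_k=2`) and adds them to the prompt as example rows, so the model sees real column values when writing SQL.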
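
Conceptually, row retrieval ranks each table's stringified rows by similarity to the question and keeps the top two. The sketch below mimics that ranking with a bag-of-words cosine similarity instead of learned embeddings; LlamaIndex's `VectorStoreIndex` uses an embedding model, OpenAI's by default.

`top_k_rows` and the sample rows are hypothetical, and the scoring is deliberately crude.

```python
import math
import re
from collections import Counter


def bag_of_words(text: str) -> Counter:
    # Lowercased word counts stand in for an embedding vector
    return Counter(re.findall(r"\w+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def top_k_rows(rows: list[tuple], query: str, k: int = 2) -> list[str]:
    """Return the k rows (stringified like the tutorial's TextNodes) most similar to the query."""
    texts = [str(tuple(row)) for row in rows]
    q = bag_of_words(query)
    # sorted() is stable, so rows with equal scores keep their original order
    ranked = sorted(texts, key=lambda t: cosine(q, bag_of_words(t)), reverse=True)
    return ranked[:k]


# Made-up sample rows
rows = [("2001", "Alpha", "North"), ("2005", "Beta", "South"), ("2009", "Gamma", "North")]
print(top_k_rows(rows, "Which releases came out on North?"))  # the two "North" rows
```

Real embeddings also capture synonyms and paraphrases, which word overlap cannot. The retrieval shape is the same either way: score every row against the question and keep the best `k`.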

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "llama-index-core>=0.11.0",
#    "llama-index-llms-openai>=0.2.0",
#    "sqlalchemy>=2.0.0",
#    "pandas>=2.0.0",
#    "requests>=2.25.0",
#    "pydantic>=2.0.0",
# ]
# main = "text_to_sql"
# params = ""
# ///

import asyncio
from pathlib import Path

import flyte
from data_ingestion import TableInfo, data_ingestion
from flyte.io import Dir, File
from llama_index.core import (
    PromptTemplate,
    SQLDatabase,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI
from sqlalchemy import create_engine, text
from utils import env

# {{docs-fragment index_tables}}
@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
    """Index a single table into vector store."""
    path = f"{table_index_dir}/{table_name}"
    engine = create_engine(database_uri)

    def _fetch_rows():
        with engine.connect() as conn:
            cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
            return cursor.fetchall()

    result = await asyncio.to_thread(_fetch_rows)
    nodes = [TextNode(text=str(tuple(row))) for row in result]
    index = VectorStoreIndex(nodes)
    index.set_index_id("vector_index")
    index.storage_context.persist(path)

    return path

@env.task
async def index_all_tables(db_file: File) -> Dir:
    """Index all tables concurrently."""
    table_index_dir = "table_indices"
    Path(table_index_dir).mkdir(exist_ok=True)

    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)

    tasks = [
        index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
        for t in sql_database.get_usable_table_names()
    ]
    await asyncio.gather(*tasks)

    remote_dir = await Dir.from_local(table_index_dir)
    return remote_dir

# {{/docs-fragment index_tables}}

@flyte.trace
async def get_table_schema_context(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
) -> str:
    """Retrieve schema + optional description context for a single table."""
    engine = create_engine(database_uri)
    sql_database = SQLDatabase(engine)

    table_info = sql_database.get_single_table_info(table_schema_obj.table_name)

    if table_schema_obj.context_str:
        table_info += f" The table description is: {table_schema_obj.context_str}"

    return table_info

@flyte.trace
async def get_table_row_context(
    table_schema_obj: SQLTableSchema,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Retrieve row-level context examples using vector search."""
    storage_context = StorageContext.from_defaults(
        persist_dir=str(f"{local_vector_index_dir}/{table_schema_obj.table_name}")
    )
    vector_index = load_index_from_storage(storage_context, index_id="vector_index")
    vector_retriever = vector_index.as_retriever(similarity_top_k=2)
    relevant_nodes = vector_retriever.retrieve(query)

    if not relevant_nodes:
        return ""

    row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
    for node in relevant_nodes:
        row_context += str(node.get_content()) + "\n"

    return row_context

async def process_table(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Combine schema + row context for one table."""
    table_info = await get_table_schema_context(table_schema_obj, database_uri)
    row_context = await get_table_row_context(
        table_schema_obj, local_vector_index_dir, query
    )

    full_context = table_info
    if row_context:
        full_context += "\n" + row_context

    print(f"Table Info: {full_context}")
    return full_context

async def get_table_context_and_rows_str(
    query: str,
    database_uri: str,
    table_schema_objs: list[SQLTableSchema],
    vector_index_dir: Dir,
):
    """Get combined schema + row context for all tables."""
    local_vector_index_dir = await vector_index_dir.download()

    # run per-table work concurrently
    context_strs = await asyncio.gather(
        *[
            process_table(t, database_uri, local_vector_index_dir, query)
            for t in table_schema_objs
        ]
    )

    return "\n\n".join(context_strs)

# {{docs-fragment retrieve_tables}}
@env.task
async def retrieve_tables(
    query: str,
    table_infos: list[TableInfo | None],
    db_file: File,
    vector_index_dir: Dir,
) -> str:
    """Retrieve relevant tables and return schema context string."""
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)

    table_node_mapping = SQLTableNodeMapping(sql_database)
    table_schema_objs = [
        SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
        for t in table_infos
        if t is not None
    ]

    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
    )
    obj_retriever = obj_index.as_retriever(similarity_top_k=3)

    retrieved_schemas = obj_retriever.retrieve(query)
    return await get_table_context_and_rows_str(
        query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
    )

# {{/docs-fragment retrieve_tables}}

def parse_response_to_sql(chat_response: ChatResponse) -> str:
    """Extract SQL query from LLM response."""
    response = chat_response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()

# {{docs-fragment sql_and_response}}
@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
    """Generate SQL query from natural language question and table context."""
    llm = OpenAI(model=model)

    fmt_messages = (
        PromptTemplate(
            prompt,
            prompt_type=PromptType.TEXT_TO_SQL,
        )
        .partial_format(dialect="sqlite")
        .format_messages(query_str=query, schema=table_context)
    )

    chat_response = await llm.achat(fmt_messages)
    return parse_response_to_sql(chat_response)

@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
    """Run SQL query on database and synthesize final response."""
    await db_file.download(local_path="local_db.sqlite")

    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)

    retrieved_rows = sql_retriever.retrieve(sql)

    response_synthesis_prompt = PromptTemplate(
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )

    llm = OpenAI(model=model)
    fmt_messages = response_synthesis_prompt.format_messages(
        sql_query=sql,
        context_str=str(retrieved_rows),
        query_str=query,
    )
    chat_response = await llm.achat(fmt_messages)
    return chat_response.message.content

# {{/docs-fragment sql_and_response}}

# {{docs-fragment text_to_sql}}
@env.task
async def text_to_sql(
    system_prompt: str = (
        "Given an input question, first create a syntactically correct {dialect} "
        "query to run, then look at the results of the query and return the answer. "
        "You can order the results by a relevant column to return the most "
        "interesting examples in the database.\n\n"
        "Never query for all the columns from a specific table, only ask for a "
        "few relevant columns given the question.\n\n"
        "Pay attention to use only the column names that you can see in the schema "
        "description. "
        "Be careful to not query for columns that do not exist. "
        "Pay attention to which column is in which table. "
        "Also, qualify column names with the table name when needed. "
        "You are required to use the following format, each taking one line:\n\n"
        "Question: Question here\n"
        "SQLQuery: SQL Query to run\n"
        "SQLResult: Result of the SQLQuery\n"
        "Answer: Final answer here\n\n"
        "Only use tables listed below.\n"
        "{schema}\n\n"
        "Question: {query_str}\n"
        "SQLQuery: "
    ),
    query: str = "What was the year that The Notorious BIG was signed to Bad Boy?",
    model: str = "gpt-4o-mini",
) -> str:
    db_file, table_infos = await data_ingestion()
    vector_index_dir = await index_all_tables(db_file)
    table_context = await retrieve_tables(query, table_infos, db_file, vector_index_dir)
    sql = await generate_sql(query, table_context, model, system_prompt)
    return await generate_response(query, sql, db_file, model)

# {{/docs-fragment text_to_sql}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(text_to_sql)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*

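Each row is serialized with `str(tuple(row))` and stored as a `TextNode` in a per-table LlamaIndex `VectorStoreIndex`, persisted under `table_indices/<table_name>`. All tables are indexed concurrently with `asyncio.gather`, and the resulting directory is uploaded as a Flyte `Dir`. At query time, this lets the system pull the rows most semantically similar to the question.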
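To make the retrieval idea concrete, here is a minimal, dependency-free sketch of the same mechanics: serialize each row with `str(tuple(row))`, score it against the question, and keep the top `k`. It swaps LlamaIndex's embedding-based similarity for a toy bag-of-words cosine score, and the sample rows are made up. Treat it as an illustration of how row retrieval works, not of the tutorial's actual retrieval quality.

```python
import math
import re
from collections import Counter

# Hypothetical rows, shaped like the tuples returned by cursor.fetchall().
rows = [
    ("The Notorious B.I.G.", "Bad Boy", 1993),
    ("Nas", "Columbia", 1992),
    ("Jay-Z", "Roc-A-Fella", 1995),
]

# Same serialization as index_table: one text document per row.
documents = [str(tuple(row)) for row in rows]


def bag_of_words(text: str) -> Counter:
    """Lowercased token counts; a toy stand-in for an embedding vector."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[token] for token, count in a.items())
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve(query: str, docs: list[str], top_k: int = 2) -> list[str]:
    """Return the top_k documents most similar to the query."""
    query_vec = bag_of_words(query)
    ranked = sorted(docs, key=lambda d: cosine(query_vec, bag_of_words(d)), reverse=True)
    return ranked[:top_k]


print(retrieve("What year was The Notorious BIG signed to Bad Boy?", documents, top_k=1))
# ["('The Notorious B.I.G.', 'Bad Boy', 1993)"]
```

With real embeddings, semantically related rows can rank highly even when they share no words with the question, which this toy scorer cannot capture.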

### Table retrieval and context building

Next, we retrieve the tables most relevant to a given query and build a context string that combines each table's schema and description with a few sample rows.

```
@flyte.trace
async def get_table_schema_context(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
) -> str:
    """Retrieve schema + optional description context for a single table."""
    engine = create_engine(database_uri)
    sql_database = SQLDatabase(engine)

    table_info = sql_database.get_single_table_info(table_schema_obj.table_name)

    if table_schema_obj.context_str:
        table_info += f" The table description is: {table_schema_obj.context_str}"

    return table_info

@flyte.trace
async def get_table_row_context(
    table_schema_obj: SQLTableSchema,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Retrieve row-level context examples using vector search."""
    storage_context = StorageContext.from_defaults(
        persist_dir=str(f"{local_vector_index_dir}/{table_schema_obj.table_name}")
    )
    vector_index = load_index_from_storage(storage_context, index_id="vector_index")
    vector_retriever = vector_index.as_retriever(similarity_top_k=2)
    relevant_nodes = vector_retriever.retrieve(query)

    if not relevant_nodes:
        return ""

    row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
    for node in relevant_nodes:
        row_context += str(node.get_content()) + "\n"

    return row_context

async def process_table(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Combine schema + row context for one table."""
    table_info = await get_table_schema_context(table_schema_obj, database_uri)
    row_context = await get_table_row_context(
        table_schema_obj, local_vector_index_dir, query
    )

    full_context = table_info
    if row_context:
        full_context += "\n" + row_context

    print(f"Table Info: {full_context}")
    return full_context

async def get_table_context_and_rows_str(
    query: str,
    database_uri: str,
    table_schema_objs: list[SQLTableSchema],
    vector_index_dir: Dir,
):
    """Get combined schema + row context for all tables."""
    local_vector_index_dir = await vector_index_dir.download()

    # run per-table work concurrently
    context_strs = await asyncio.gather(
        *[
            process_table(t, database_uri, local_vector_index_dir, query)
            for t in table_schema_objs
        ]
    )

    return "\n\n".join(context_strs)

@env.task
async def retrieve_tables(
    query: str,
    table_infos: list[TableInfo | None],
    db_file: File,
    vector_index_dir: Dir,
) -> str:
    """Retrieve relevant tables and return schema context string."""
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)

    table_node_mapping = SQLTableNodeMapping(sql_database)
    table_schema_objs = [
        SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
        for t in table_infos
        if t is not None
    ]

    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
    )
    obj_retriever = obj_index.as_retriever(similarity_top_k=3)

    retrieved_schemas = obj_retriever.retrieve(query)
    return await get_table_context_and_rows_str(
        query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*

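The object retriever selects the three tables (`similarity_top_k=3`) most semantically similar to the query. For each one, we attach its schema, its description, and the two most similar rows from its vector index (`similarity_top_k=2`). This context grounds the model's SQL generation in the database's actual structure and content.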
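How each table's context string is assembled can be sketched in plain Python. The sketch mirrors the string formatting in `get_table_schema_context`, `get_table_row_context`, and `process_table`. The schema line and rows in the example are invented; in the real pipeline, the schema text comes from `SQLDatabase.get_single_table_info` and the rows come from the vector index. The helper names are hypothetical.

```python
from typing import Optional

# Header text copied from get_table_row_context.
ROW_HEADER = "\nHere are some relevant example rows (values in the same order as columns above)\n"


def table_context(schema_info: str, description: Optional[str], example_rows: list[str]) -> str:
    """Schema, optional description, then example rows for one table."""
    context = schema_info
    if description:
        context += f" The table description is: {description}"
    if example_rows:
        context += "\n" + ROW_HEADER + "".join(row + "\n" for row in example_rows)
    return context


def combined_context(tables: list[tuple[str, Optional[str], list[str]]]) -> str:
    """Join per-table contexts with blank lines, as get_table_context_and_rows_str does."""
    return "\n\n".join(table_context(*table) for table in tables)


print(combined_context([
    (
        "Table 'signings' has columns: artist (TEXT), label (TEXT), year (INTEGER).",
        "Record-label signings by artist",
        ["('The Notorious B.I.G.', 'Bad Boy', 1993)"],
    ),
]))
```

Because every retrieved table contributes its own block, the prompt grows linearly with `similarity_top_k`. This is one reason to keep both retrievers' `k` values small.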

### SQL generation and response synthesis

Finally, we generate SQL queries and produce natural language answers.

```
def parse_response_to_sql(chat_response: ChatResponse) -> str:
    """Extract SQL query from LLM response."""
    response = chat_response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()

@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
    """Generate SQL query from natural language question and table context."""
    llm = OpenAI(model=model)

    fmt_messages = (
        PromptTemplate(
            prompt,
            prompt_type=PromptType.TEXT_TO_SQL,
        )
        .partial_format(dialect="sqlite")
        .format_messages(query_str=query, schema=table_context)
    )

    chat_response = await llm.achat(fmt_messages)
    return parse_response_to_sql(chat_response)

@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
    """Run SQL query on database and synthesize final response."""
    await db_file.download(local_path="local_db.sqlite")

    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)

    retrieved_rows = sql_retriever.retrieve(sql)

    response_synthesis_prompt = PromptTemplate(
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )

    llm = OpenAI(model=model)
    fmt_messages = response_synthesis_prompt.format_messages(
        sql_query=sql,
        context_str=str(retrieved_rows),
        query_str=query,
    )
    chat_response = await llm.achat(fmt_messages)
    return chat_response.message.content
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py*

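The SQL generation prompt includes the schema, example rows, and formatting rules. `parse_response_to_sql` extracts the query from the model's reply. `generate_response` then runs the query against the SQLite database with `SQLRetriever` and asks the LLM to synthesize a natural language answer from the returned rows.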
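The tutorial's `parse_response_to_sql` keeps the text between `SQLQuery:` and `SQLResult:` and then strips backtick characters. If the model wraps its query in a fenced block with a language tag, that tag (for example `sql`) would be left in front of the query. The sketch below is a slightly more defensive variant that removes the whole fence. It works on plain strings rather than `ChatResponse` objects, and the name `extract_sql` is hypothetical.

```python
import re

# Matches a Markdown code fence with an optional language tag, e.g. sql.
FENCE_RE = re.compile(r"^`{3}\w*\s*(.*?)\s*`{3}$", re.DOTALL)


def extract_sql(reply: str) -> str:
    """Return the SQL between 'SQLQuery:' and 'SQLResult:', minus any code fence."""
    start = reply.find("SQLQuery:")
    if start != -1:
        reply = reply[start + len("SQLQuery:"):]
    end = reply.find("SQLResult:")
    if end != -1:
        reply = reply[:end]
    reply = reply.strip()
    fenced = FENCE_RE.match(reply)
    if fenced:
        reply = fenced.group(1)
    return reply.strip()


fence = "`" * 3
reply = f"SQLQuery: {fence}sql\nSELECT year FROM signings WHERE label = 'Bad Boy';\n{fence}\nSQLResult: 1993"
print(extract_sql(reply))  # SELECT year FROM signings WHERE label = 'Bad Boy';
```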

At this point, we have an end-to-end Text-to-SQL pipeline: natural language questions go in, SQL queries run, and answers come back. To make this workflow production-ready, we rely on several Flyte 2 capabilities. Caching ensures that repeated steps, such as table ingestion or vector indexing, don’t rerun unnecessarily, which saves time and compute. Containerization provides consistent, reproducible execution across environments, making the workflow easier to scale and deploy. Observability features let us track every step of the pipeline, monitor performance, and debug issues quickly.

The pipeline works end-to-end. To measure how it performs across many different questions and to improve it gradually, we can start experimenting with prompt tuning.

Two things help make this process meaningful:

- **A clean evaluation dataset** - so we can measure accuracy against trusted ground truth.
- **A systematic evaluation loop** - so we can see whether prompt changes or other adjustments actually help.
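One common way to build such an evaluation loop is *execution accuracy*. You run the generated SQL and the ground-truth SQL against the same database, and you count a hit when both return the same rows. The sketch below shows one way such a loop could work; it is not the tutorial's own evaluation code, and the helper names and data are hypothetical. It ignores row order, which is a judgment call for questions whose answers depend on `ORDER BY`.

```python
import sqlite3
from collections import Counter


def execution_match(conn: sqlite3.Connection, predicted_sql: str, gold_sql: str) -> bool:
    """True when both queries run and return the same rows, ignoring row order."""
    try:
        predicted_rows = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        # A query that fails to run counts as a miss.
        return False
    gold_rows = conn.execute(gold_sql).fetchall()
    # Compare as multisets so duplicates still matter but ordering does not.
    return Counter(predicted_rows) == Counter(gold_rows)


def execution_accuracy(conn: sqlite3.Connection, pairs: list[tuple[str, str]]) -> float:
    """Fraction of (predicted_sql, gold_sql) pairs whose results match."""
    if not pairs:
        return 0.0
    hits = sum(execution_match(conn, predicted, gold) for predicted, gold in pairs)
    return hits / len(pairs)


# Usage with a toy in-memory database (hypothetical data).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE signings (artist TEXT, label TEXT, year INTEGER)")
conn.executemany(
    "INSERT INTO signings VALUES (?, ?, ?)",
    [("The Notorious B.I.G.", "Bad Boy", 1993), ("Nas", "Columbia", 1992)],
)
pairs = [
    # Different SQL text, same result: counts as correct.
    ("SELECT year FROM signings WHERE artist = 'The Notorious B.I.G.'",
     "SELECT s.year FROM signings AS s WHERE s.label = 'Bad Boy'"),
    # References a column that does not exist: counts as incorrect.
    ("SELECT signed_year FROM signings WHERE artist = 'Nas'",
     "SELECT year FROM signings WHERE artist = 'Nas'"),
]
print(execution_accuracy(conn, pairs))  # 0.5
```

Comparing results instead of SQL strings lets differently written but equivalent queries count as correct.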

With these in place, the next step is to build a "golden" QA dataset that will guide iterative prompt optimization.

## Building the QA dataset

> **📝 Note**
>
> The WikiTableQuestions dataset already includes question–answer pairs, available in its [GitHub repository](https://github.com/ppasupat/WikiTableQuestions/tree/master/data). To use them in this workflow, you'll need to adapt the data into the required format, but the raw material is there for you to build on.

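We generate a dataset of natural language questions paired with executable SQL queries. An LLM proposes question/SQL pairs from chunks of the schema. Each query is then executed against the database, and an LLM judge keeps only the pairs whose results answer the question. This dataset serves as the benchmark for prompt tuning and evaluation.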
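Before spending an LLM call on validation, a cheap pre-filter can reject generated queries that don't even compile. The sketch below uses SQLite's `EXPLAIN`, which prepares a statement (resolving its tables and columns) without running it. This is an alternative check, not part of the tutorial, which executes each query directly in `validate_sql`. `compiles_in_sqlite` and the table are hypothetical.

```python
import sqlite3


def compiles_in_sqlite(conn: sqlite3.Connection, sql: str) -> bool:
    """Return True if SQLite can prepare `sql`; the EXPLAIN prefix means it is not executed."""
    try:
        conn.execute("EXPLAIN " + sql)
        return True
    except sqlite3.Error:
        # Syntax errors, unknown tables, unknown columns, and similar failures.
        return False


# Usage with a throwaway in-memory database (hypothetical table).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE signings (artist TEXT, label TEXT, year INTEGER)")
conn.execute("INSERT INTO signings VALUES ('The Notorious B.I.G.', 'Bad Boy', 1993)")
print(compiles_in_sqlite(conn, "SELECT year FROM signings WHERE label = 'Bad Boy'"))  # True
print(compiles_in_sqlite(conn, "SELECT signing_year FROM signings"))  # False
```

Because nothing is executed, even a generated `DELETE` statement is only checked, not applied.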

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas>=2.0.0",
#    "llama-index-core>=0.11.0",
#    "llama-index-llms-openai>=0.2.0",
#    "pydantic>=2.0.0",
# ]
# main = "build_eval_dataset"
# params = ""
# ///

import sqlite3

import flyte
import pandas as pd
from data_ingestion import data_ingestion
from flyte.io import File
from llama_index.core import PromptTemplate
from llama_index.llms.openai import OpenAI
from utils import env
from pydantic import BaseModel

class QAItem(BaseModel):
    question: str
    sql: str

class QAList(BaseModel):
    items: list[QAItem]

# {{docs-fragment get_and_split_schema}}
@env.task
async def get_and_split_schema(db_file: File, tables_per_chunk: int) -> list[str]:
    """
    Download the SQLite DB, extract schema info (columns + sample rows),
    then split it into chunks with up to `tables_per_chunk` tables each.
    """
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()

    tables = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    schema_blocks = []
    for table in tables:
        table_name = table[0]

        # columns (quote the table name so names with spaces or punctuation still work)
        cursor.execute(f'PRAGMA table_info("{table_name}");')
        columns = [col[1] for col in cursor.fetchall()]
        block = f"Table: {table_name}({', '.join(columns)})"

        # sample rows
        cursor.execute(f'SELECT * FROM "{table_name}" LIMIT 10;')
        rows = cursor.fetchall()
        if rows:
            block += "\nSample rows:\n"
            for row in rows:
                block += f"{row}\n"

        schema_blocks.append(block)

    conn.close()

    chunks = []
    current_chunk = []
    for block in schema_blocks:
        current_chunk.append(block)
        if len(current_chunk) >= tables_per_chunk:
            chunks.append("\n".join(current_chunk))
            current_chunk = []
    if current_chunk:
        chunks.append("\n".join(current_chunk))

    return chunks

# {{/docs-fragment get_and_split_schema}}

# {{docs-fragment generate_questions_and_sql}}
@flyte.trace
async def generate_questions_and_sql(
    schema: str, num_samples: int, batch_size: int
) -> QAList:
    llm = OpenAI(model="gpt-4.1")

    prompt_tmpl = PromptTemplate(
        """Prompt: You are helping build a Text-to-SQL dataset.

Here is the database schema:
{schema}

Generate {num} natural language questions a user might ask about this database.
For each question, also provide the correct SQL query.

Reasoning process (you must follow this internally):

- Given an input question, first create a syntactically correct {dialect} SQL query.
- Never use SELECT *; only include the relevant columns.
- Use only columns/tables from the schema. Qualify column names when ambiguous.
- You may order results by a meaningful column to make the query more useful.
- Be careful not to add unnecessary columns.
- Use filters, aggregations, joins, grouping, and subqueries when relevant.

Final Output:
Return only a JSON object with one field:

- "items": a list of {num} objects, each with:
    - "question": the natural language question
    - "sql": the corresponding SQL query
"""
    )

    all_items: list[QAItem] = []

    # batch generation
    for start in range(0, num_samples, batch_size):
        current_num = min(batch_size, num_samples - start)
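        # async variant avoids blocking the event loop inside this async function
        response = await llm.astructured_predict(
            QAList,
            prompt_tmpl,
            schema=schema,
            num=current_num,
            dialect="sqlite",  # fills the {dialect} placeholder in the prompt
        )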
        all_items.extend(response.items)

    # deduplicate
    seen = set()
    unique_items: list[QAItem] = []
    for item in all_items:
        key = (item.question.strip().lower(), item.sql.strip().lower())
        if key not in seen:
            seen.add(key)
            unique_items.append(item)

    return QAList(items=unique_items[:num_samples])

# {{/docs-fragment generate_questions_and_sql}}

@flyte.trace
async def llm_validate_batch(pairs: list[dict[str, str]]) -> list[str]:
    """Validate a batch of question/sql/result dicts using one LLM call."""
    batch_prompt = """You are validating the correctness of SQL query results against the question.
For each example, answer only "True" (correct) or "False" (incorrect).
Output one answer per line, in the same order as the examples.
---
"""

    for i, pair in enumerate(pairs, start=1):
        batch_prompt += f"""
Example {i}:
Question:
{pair['question']}

SQL:
{pair['sql']}

Result:
{pair['rows']}
---
"""

    llm = OpenAI(model="gpt-4.1")
    resp = await llm.acomplete(batch_prompt)

    # Expect exactly one True/False per example
    results = [
        line.strip()
        for line in resp.text.splitlines()
        if line.strip() in ("True", "False")
    ]
    return results

# {{docs-fragment validate_sql}}
@env.task
async def validate_sql(
    db_file: File, question_sql_pairs: QAList, batch_size: int
) -> list[dict[str, str]]:
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()

    qa_data: list[dict[str, str]] = []
    batch: list[dict[str, str]] = []

    async def judge_batch(examples: list[dict[str, str]]) -> None:
        """Ask the LLM judge about a batch and keep only the pairs marked correct."""
        results = await llm_validate_batch(examples)
        if len(results) != len(examples):
            # One verdict per example is required; otherwise zip() would pair
            # verdicts with the wrong examples, so drop the whole batch.
            print(f"Judge returned {len(results)} verdicts for {len(examples)} examples; skipping batch")
            return
        for item, is_valid in zip(examples, results):
            if is_valid == "True":
                qa_data.append(
                    {
                        "input": item["question"],
                        "sql": item["sql"],
                        "target": item["rows"],
                    }
                )
            else:
                print(f"Filtered out incorrect result for: {item['question']}")

    for pair in question_sql_pairs.items:
        # Only SQL execution is guarded here; errors from the LLM judge propagate.
        try:
            cursor.execute(pair.sql)
            rows = cursor.fetchall()
        except Exception as e:
            print(f"Skipping invalid SQL: {pair.sql} ({e})")
            continue
        batch.append({"question": pair.question, "sql": pair.sql, "rows": str(rows)})

        # Process the batch once it is full.
        if len(batch) == batch_size:
            await judge_batch(batch)
            batch = []

    # Process the leftover partial batch.
    if batch:
        await judge_batch(batch)

    conn.close()
    return qa_data

# {{/docs-fragment validate_sql}}

@flyte.trace
async def save_to_csv(qa_data: list[dict]) -> File:
    df = pd.DataFrame(qa_data, columns=["input", "target", "sql"])

    csv_file = "qa_dataset.csv"
    df.to_csv(csv_file, index=False)

    return await File.from_local(csv_file)

# {{docs-fragment build_eval_dataset}}
@env.task
async def build_eval_dataset(
    num_samples: int = 300, batch_size: int = 30, tables_per_chunk: int = 3
) -> File:
    db_file, _ = await data_ingestion()
    schema_chunks = await get_and_split_schema(db_file, tables_per_chunk)

    per_chunk_samples = max(1, num_samples // len(schema_chunks))
    final_qa_data = []

    for chunk in schema_chunks:
        qa_list = await generate_questions_and_sql(
            schema=chunk,
            num_samples=per_chunk_samples,
            batch_size=batch_size,
        )
        qa_data = await validate_sql(db_file, qa_list, batch_size)
        final_qa_data.extend(qa_data)

    csv_file = await save_to_csv(final_qa_data)
    return csv_file

# {{/docs-fragment build_eval_dataset}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(build_eval_dataset)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py*

The pipeline does the following:

- Schema extraction: pull each table's name, column names, and up to 10 sample rows from the database.
- Question–SQL generation: use an LLM to produce natural language questions with matching SQL queries.
- Validation: run each query against the database, drop queries that fail to execute, and use a second LLM call to drop results that don't answer the question.
- Final export: store the validated pairs in a CSV file with `input`, `target`, and `sql` columns for downstream use.
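
To see how the `num_samples`, `tables_per_chunk`, and `batch_size` parameters of `build_eval_dataset` interact, here is a small pure-Python sketch that mirrors its arithmetic without calling an LLM: the sample budget is divided evenly across schema chunks (at least one sample per chunk), and each chunk's share is requested in calls of at most `batch_size` pairs. The helper `plan_generation` and the table counts are hypothetical, and the numbers are requested pairs before deduplication and validation, so the final dataset is usually smaller.

```python
import math


def plan_generation(
    num_tables: int, num_samples: int = 300, batch_size: int = 30, tables_per_chunk: int = 3
) -> list[list[int]]:
    """Return, for each schema chunk, the number of pairs requested per LLM call.

    Hypothetical helper mirroring the arithmetic in build_eval_dataset and
    generate_questions_and_sql; it makes no LLM calls.
    """
    # get_and_split_schema packs up to tables_per_chunk tables into each chunk.
    num_chunks = math.ceil(num_tables / tables_per_chunk)
    # build_eval_dataset gives every chunk an equal share (at least 1).
    per_chunk_samples = max(1, num_samples // num_chunks)
    # generate_questions_and_sql requests at most batch_size pairs per call.
    batches = [
        min(batch_size, per_chunk_samples - start)
        for start in range(0, per_chunk_samples, batch_size)
    ]
    return [list(batches) for _ in range(num_chunks)]


plan = plan_generation(num_tables=10)
print(len(plan), plan[0], sum(map(sum, plan)))  # prints: 4 [30, 30, 15] 300
```

Because of the integer division, the total can fall slightly short of `num_samples`: with 20 tables there are 7 chunks of 42 requested pairs each, or 294 in total.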

### Schema extraction and chunking

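We split the schema into chunks of up to `tables_per_chunk` tables (three by default) and generate questions for each chunk separately. Because every chunk gets its own share of the sample budget, questions are spread across all tables instead of clustering on a few prominent ones, which gives the dataset broad coverage.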

```
@env.task
async def get_and_split_schema(db_file: File, tables_per_chunk: int) -> list[str]:
    """
    Download the SQLite DB, extract schema info (columns + sample rows),
    then split it into chunks with up to `tables_per_chunk` tables each.
    """
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()

    tables = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    schema_blocks = []
    for table in tables:
        table_name = table[0]

        # columns
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = [col[1] for col in cursor.fetchall()]
        block = f"Table: {table_name}({', '.join(columns)})"

        # sample rows
        cursor.execute(f"SELECT * FROM {table_name} LIMIT 10;")
        rows = cursor.fetchall()
        if rows:
            block += "\nSample rows:\n"
            for row in rows:
                block += f"{row}\n"

        schema_blocks.append(block)

    conn.close()

    chunks = []
    current_chunk = []
    for block in schema_blocks:
        current_chunk.append(block)
        if len(current_chunk) >= tables_per_chunk:
            chunks.append("\n".join(current_chunk))
            current_chunk = []
    if current_chunk:
        chunks.append("\n".join(current_chunk))

    return chunks
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py*
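
The extraction and chunking logic can be tried locally without Flyte or an LLM. The sketch below builds a throwaway in-memory SQLite database with five hypothetical tables, produces one block per table in the same format as `get_and_split_schema` (columns plus up to 10 sample rows), and groups the blocks into chunks of `tables_per_chunk`. Unlike the tutorial code, it quotes table names and skips SQLite's internal `sqlite_%` tables; both are assumptions added for robustness.

```python
import sqlite3


def schema_chunks(conn: sqlite3.Connection, tables_per_chunk: int) -> list[str]:
    """Describe each table (columns + sample rows) and group the blocks into chunks."""
    cursor = conn.cursor()
    tables = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
    ).fetchall()

    blocks = []
    for (table_name,) in tables:
        # Quote the identifier so names with spaces or SQL keywords still work.
        quoted = '"' + table_name.replace('"', '""') + '"'
        columns = [col[1] for col in cursor.execute(f"PRAGMA table_info({quoted});")]
        block = f"Table: {table_name}({', '.join(columns)})"
        rows = cursor.execute(f"SELECT * FROM {quoted} LIMIT 10;").fetchall()
        if rows:
            block += "\nSample rows:\n" + "".join(f"{row}\n" for row in rows)
        blocks.append(block)

    # Slice the blocks into groups of at most tables_per_chunk.
    return [
        "\n".join(blocks[i : i + tables_per_chunk])
        for i in range(0, len(blocks), tables_per_chunk)
    ]


conn = sqlite3.connect(":memory:")
for i in range(5):
    conn.execute(f"CREATE TABLE t{i} (id INTEGER, name TEXT)")
    conn.execute(f"INSERT INTO t{i} VALUES (1, 'a'), (2, 'b')")

chunks = schema_chunks(conn, tables_per_chunk=3)
print(len(chunks))  # 2: the first chunk holds three tables, the second holds two
conn.close()
```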

### Question and SQL generation

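For each schema chunk, we prompt an LLM for realistic questions a user might ask, each paired with a SQL query. LlamaIndex's structured prediction parses the reply into the `QAList` Pydantic model, and requests are made in batches of at most `batch_size` pairs. The generated SQL is not guaranteed to be valid; that is checked in the next step. Finally, exact duplicates (question and SQL compared case-insensitively after trimming whitespace) are removed, and the list is capped at `num_samples`.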

```
@flyte.trace
async def generate_questions_and_sql(
    schema: str, num_samples: int, batch_size: int
) -> QAList:
    llm = OpenAI(model="gpt-4.1")

    prompt_tmpl = PromptTemplate(
        """Prompt: You are helping build a Text-to-SQL dataset.

Here is the database schema:
{schema}

Generate {num} natural language questions a user might ask about this database.
For each question, also provide the correct SQL query.

Reasoning process (you must follow this internally):

- Given an input question, first create a syntactically correct {dialect} SQL query.
- Never use SELECT *; only include the relevant columns.
- Use only columns/tables from the schema. Qualify column names when ambiguous.
- You may order results by a meaningful column to make the query more useful.
- Be careful not to add unnecessary columns.
- Use filters, aggregations, joins, grouping, and subqueries when relevant.

Final Output:
Return only a JSON object with one field:

- "items": a list of {num} objects, each with:
    - "question": the natural language question
    - "sql": the corresponding SQL query
"""
    )

    all_items: list[QAItem] = []

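    # Batch generation: request at most `batch_size` pairs per LLM call.
    for start in range(0, num_samples, batch_size):
        current_num = min(batch_size, num_samples - start)
        # Use the async variant so the call doesn't block the event loop.
        response = await llm.astructured_predict(
            QAList,
            prompt_tmpl,
            schema=schema,
            num=current_num,
            dialect="SQLite",  # fills the {dialect} placeholder in the prompt
        )
        all_items.extend(response.items)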

    # deduplicate
    seen = set()
    unique_items: list[QAItem] = []
    for item in all_items:
        key = (item.question.strip().lower(), item.sql.strip().lower())
        if key not in seen:
            seen.add(key)
            unique_items.append(item)

    return QAList(items=unique_items[:num_samples])
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py*
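
The deduplication step is easy to check in isolation. This sketch uses a stdlib dataclass as a stand-in for the tutorial's Pydantic `QAItem` model and reproduces the same key (question and SQL, trimmed and lower-cased), keeping the first occurrence and then capping the list at `num_samples`. Note that it only catches textual repeats: two differently worded questions with identical SQL both survive.

```python
from dataclasses import dataclass


@dataclass
class QAItem:
    """Stand-in for the tutorial's Pydantic QAItem model."""

    question: str
    sql: str


def dedupe(items: list[QAItem], num_samples: int) -> list[QAItem]:
    """Keep the first occurrence of each (question, sql) pair, compared case-insensitively."""
    seen: set[tuple[str, str]] = set()
    unique: list[QAItem] = []
    for item in items:
        key = (item.question.strip().lower(), item.sql.strip().lower())
        if key not in seen:
            seen.add(key)
            unique.append(item)
    # Cap the result at the requested number of samples.
    return unique[:num_samples]


items = [
    QAItem("How many players are there?", "SELECT COUNT(*) FROM players"),
    QAItem("how many players are there? ", "select count(*) from players"),
    QAItem("Count the players.", "SELECT COUNT(*) FROM players"),
]
result = dedupe(items, num_samples=10)
print(len(result))  # 2: the second item repeats the first; the third is reworded
```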

### Validation and quality control

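Each generated SQL query is executed against the SQLite database, and queries that raise an error are dropped. The surviving question, SQL, and result triples are sent to an LLM judge in batches of `batch_size`; the judge answers `True` or `False` for each example, and only the pairs marked `True` are kept. The stringified query result becomes the `target` column.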

```
@flyte.trace
async def llm_validate_batch(pairs: list[dict[str, str]]) -> list[str]:
    """Validate a batch of question/sql/result dicts using one LLM call."""
    batch_prompt = """You are validating the correctness of SQL query results against the question.
For each example, answer only "True" (correct) or "False" (incorrect).
Output one answer per line, in the same order as the examples.
---
"""

    for i, pair in enumerate(pairs, start=1):
        batch_prompt += f"""
Example {i}:
Question:
{pair['question']}

SQL:
{pair['sql']}

Result:
{pair['rows']}
---
"""

    llm = OpenAI(model="gpt-4.1")
    resp = await llm.acomplete(batch_prompt)

    # Expect exactly one True/False per example
    results = [
        line.strip()
        for line in resp.text.splitlines()
        if line.strip() in ("True", "False")
    ]
    return results

# {{docs-fragment validate_sql}}
@env.task
async def validate_sql(
    db_file: File, question_sql_pairs: QAList, batch_size: int
) -> list[dict[str, str]]:
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()

    qa_data: list[dict[str, str]] = []
    batch: list[dict[str, str]] = []

    async def judge_batch(examples: list[dict[str, str]]) -> None:
        """Ask the LLM judge about a batch and keep only the pairs marked correct."""
        results = await llm_validate_batch(examples)
        if len(results) != len(examples):
            # One verdict per example is required; otherwise zip() would pair
            # verdicts with the wrong examples, so drop the whole batch.
            print(f"Judge returned {len(results)} verdicts for {len(examples)} examples; skipping batch")
            return
        for item, is_valid in zip(examples, results):
            if is_valid == "True":
                qa_data.append(
                    {
                        "input": item["question"],
                        "sql": item["sql"],
                        "target": item["rows"],
                    }
                )
            else:
                print(f"Filtered out incorrect result for: {item['question']}")

    for pair in question_sql_pairs.items:
        # Only SQL execution is guarded here; errors from the LLM judge propagate.
        try:
            cursor.execute(pair.sql)
            rows = cursor.fetchall()
        except Exception as e:
            print(f"Skipping invalid SQL: {pair.sql} ({e})")
            continue
        batch.append({"question": pair.question, "sql": pair.sql, "rows": str(rows)})

        # Process the batch once it is full.
        if len(batch) == batch_size:
            await judge_batch(batch)
            batch = []

    # Process the leftover partial batch.
    if batch:
        await judge_batch(batch)

    conn.close()
    return qa_data
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py*
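
`llm_validate_batch` relies on the judge returning exactly one `True`/`False` line per example, in order, and `zip` then pairs verdicts with examples. If the model skips or adds a verdict line, verdicts shift onto the wrong examples without any error. The sketch below parses a hard-coded reply string (standing in for a real LLM response) the same way the tutorial does, then adds a length check before keeping only the pairs marked `True`. The helper names `parse_verdicts` and `keep_valid` are hypothetical.

```python
def parse_verdicts(reply: str) -> list[str]:
    """Extract one 'True'/'False' verdict per line, ignoring any other lines."""
    return [
        line.strip()
        for line in reply.splitlines()
        if line.strip() in ("True", "False")
    ]


def keep_valid(batch: list[dict[str, str]], reply: str) -> list[dict[str, str]]:
    """Return QA rows for pairs the judge marked True; refuse misaligned replies."""
    verdicts = parse_verdicts(reply)
    if len(verdicts) != len(batch):
        raise ValueError(f"expected {len(batch)} verdicts, got {len(verdicts)}")
    return [
        {"input": pair["question"], "sql": pair["sql"], "target": pair["rows"]}
        for pair, verdict in zip(batch, verdicts)
        if verdict == "True"
    ]


batch = [
    {"question": "How many players?", "sql": "SELECT COUNT(*) FROM players", "rows": "[(42,)]"},
    {"question": "Oldest player?", "sql": "SELECT name FROM players", "rows": "[('Ann',), ('Bo',)]"},
]
reply = "Here are my answers:\nTrue\nFalse\n"
print(keep_valid(batch, reply))
```

Raising on a mismatch is one possible policy; dropping or re-judging the batch are reasonable alternatives, depending on whether losing samples or mislabeling them is more costly for your dataset.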

Even with automated checks, human review remains critical. Since this dataset serves as the ground truth, mislabeled pairs can distort evaluation. For production use, always invest in human-in-the-loop review.

Support for human-in-the-loop pipelines is coming soon in Flyte 2!

## Optimizing prompts

With the QA dataset in place, we can turn to prompt optimization. The idea: start from a baseline prompt, generate new variants, and measure whether accuracy improves.
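
At its core this is a search loop: score the baseline prompt on the validation split, ask for a new candidate prompt, score it, and keep whichever is best, leaving the test split for a final check. A minimal sketch with the LLM calls replaced by plain functions; `optimize_prompt`, `propose`, and `score` are hypothetical names. In practice, `propose` would ask an LLM to rewrite the prompt given past results, and `score` would generate SQL for each validation question and compare its result against the `answer` column.

```python
from typing import Callable


def optimize_prompt(
    baseline: str,
    propose: Callable[[str, list[tuple[str, float]]], str],
    score: Callable[[str], float],
    rounds: int = 3,
) -> tuple[str, float]:
    """Greedy prompt search: keep the highest-scoring prompt seen so far."""
    history = [(baseline, score(baseline))]
    best_prompt, best_score = history[0]
    for _ in range(rounds):
        # The proposer sees the current best prompt and all past (prompt, score) pairs.
        candidate = propose(best_prompt, history)
        candidate_score = score(candidate)
        history.append((candidate, candidate_score))
        if candidate_score > best_score:
            best_prompt, best_score = candidate, candidate_score
    return best_prompt, best_score


# Toy stand-ins: each proposal appends a rule; the score rewards up to three rules.
def toy_propose(prompt: str, history: list[tuple[str, float]]) -> str:
    return prompt + f"\n- rule {len(history)}"


def toy_score(prompt: str) -> float:
    return min(prompt.count("- rule"), 3) / 3


best, acc = optimize_prompt("Translate the question to SQL.", toy_propose, toy_score, rounds=5)
print(acc)
```

Only strict improvements replace the current best, so a later candidate that merely ties does not displace an earlier, shorter prompt.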

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas>=2.0.0",
#    "sqlalchemy>=2.0.0",
#    "llama-index-core>=0.11.0",
#    "llama-index-llms-openai>=0.2.0",
#    "litellm",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from data_ingestion import TableInfo
from flyte.io import Dir, File
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import SQLRetriever
from sqlalchemy import create_engine
from text_to_sql import data_ingestion, generate_sql, index_all_tables, retrieve_tables
from utils import env

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a CSV (a Flyte File, or a path/URL such as a public
    Google Sheet CSV export) and split it into validation and test DataFrames.
    The CSV must have 'input' and 'target' columns.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Val/Test split
    df_renamed = df.rename(columns={"input": "question", "target": "answer"})

    n = len(df_renamed)
    split = n // 2

    df_val = df_renamed.iloc[:split]
    df_test = df_renamed.iloc[split:]

    return df_val, df_test

@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
    await db_file.download(local_path="local_db.sqlite")

    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)

    retrieved_rows = sql_retriever.retrieve(sql)

    if retrieved_rows:
        # Get the structured result and stringify
        return str(retrieved_rows[0].node.metadata["result"])

    return ""

async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    db_file: File,
    table_infos: list[TableInfo | None],
    vector_index_dir: Dir,
) -> dict:
    # Generate response from target model
    table_context = await retrieve_tables(
        question, table_infos, db_file, vector_index_dir
    )
    sql = await generate_sql(
        question,
        table_context,
        target_model_config.model_name,
        target_model_config.prompt,
    )
    sql = sql.replace("sql\n", "")

    try:
        response = await generate_response(db_file, sql)
    except Exception as e:
        print(f"Failed to generate response for question {question}: {e}")
        response = None

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                query_str=question,
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "sql": sql,
        "is_correct": verdict_clean == "true",
    }

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    sql,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
    db_file,
    table_infos,
    vector_index_dir,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
                db_file,
                table_infos,
                vector_index_dir,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(str(question))}</td>
                    <td>{html.escape(str(answer))}</td>
                    <td>{html.escape(str(sql))}</td>
                    <td>{html.escape(str(result['model_response']))}</td>
                    <td>{html.escape(result['sql'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

@dataclass
class DatabaseConfig:
    csv_zip_path: str
    search_glob: str
    concurrency: int
    model: str

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
    db_config: DatabaseConfig,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Ground Truth Answer</th>
                    <th>Ground Truth SQL</th>
                    <th>Model Response</th>
                    <th>Model SQL</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    db_file, table_infos = await data_ingestion(
        db_config.csv_zip_path,
        db_config.search_glob,
        db_config.concurrency,
        db_config.model,
    )

    vector_index_dir = await index_all_tables(db_file)

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            row.sql,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
            db_file,
            table_infos,
            vector_index_dir,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_val: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
    db_config: DatabaseConfig,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_val,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Propose and evaluate up to max_iterations new prompts, streaming after each.
    # A fixed budget guarantees termination even if the optimizer never returns a [[...]] prompt.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_val,
                target_model_config,
                review_model_config,
                concurrency,
                db_config,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:+.2f} percentage points</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6:.0f}px; height:10px; border-radius:4px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    ground_truth_csv: File | str = "/root/ground_truth.csv",
    db_config: DatabaseConfig = DatabaseConfig(
        csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
        search_glob="WikiTableQuestions/csv/200-csv/*.csv",
        concurrency=5,
        model="gpt-4o-mini",
    ),
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""Given an input question, create a syntactically correct {dialect} query to run.

Schema:
{schema}

Question: {query_str}

SQL query to run:
""",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.

- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".

Question:
{query_str}

Ground Truth:
{answer}

Model Response:
{response}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.

<EXAMPLE>
<SCHEMA>
artists(id, name)
albums(id, title, artist_id, release_year)
</SCHEMA>
<QUESTION>
How many albums did The Beatles release?
</QUESTION>
<ANSWER>
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 5,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
        ground_truth_csv = await File.from_local(ground_truth_csv)

    df_val, df_test = await data_prep(ground_truth_csv)

    best_prompt, val_accuracy = await prompt_optimizer(
        df_val,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
        db_config,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

    return {
        "best_prompt": best_prompt,
        "validation_accuracy": val_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py*
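
The core of `prompt_optimizer` is a small loop: score the baseline, show the optimizer every prompt so far (worst to best), pull a new prompt out of `[[...]]` in its reply, score it, and keep the best. The sketch below distills that loop into plain Python with no Flyte or LLM calls; `evaluate` and `propose` are hypothetical stand-ins for `evaluate_prompt` and the optimizer model call, so treat it as an illustration under those assumptions rather than the tutorial's implementation.

```python
import re
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


def extract_prompt(response: str) -> Optional[str]:
    """Return the text inside the first [[...]] pair (whitespace-stripped), or None."""
    match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
    return match.group(1).strip() if match else None


def optimize(
    baseline: str,
    evaluate: Callable[[str], float],  # hypothetical stand-in for evaluate_prompt
    propose: Callable[[str], str],  # hypothetical stand-in for the optimizer model
    max_iterations: int,
) -> PromptResult:
    """Score the baseline, request up to max_iterations new prompts, return the best."""
    results = [PromptResult(baseline, evaluate(baseline))]
    for _ in range(max_iterations):  # fixed budget: unparseable replies still use a slot
        # Give the optimizer every prompt so far, sorted from worst to best.
        scoreboard = "\n".join(
            f"{r.prompt}: {r.accuracy:.2f}"
            for r in sorted(results, key=lambda r: r.accuracy)
        )
        candidate = extract_prompt(propose(scoreboard))
        if candidate is None:
            continue  # no [[...]] in the reply; skip this round
        results.append(PromptResult(candidate, evaluate(candidate)))
    return max(results, key=lambda r: r.accuracy)
```

Counting failed proposals against the budget is what guarantees the loop ends; a loop keyed only on the number of successfully parsed prompts can spin forever if the optimizer keeps omitting the brackets.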

### Evaluation pipeline

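We evaluate each prompt variant against the golden dataset, which `data_prep` shuffles and splits 50/50 into validation and test sets. Candidate prompts are scored on the validation set while `evaluate_prompt` streams per-row results and a running accuracy into a live Flyte report; the test set is held back to compare the baseline prompt with the best prompt at the end.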
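
Per-row scoring in `evaluate_prompt` follows a common asyncio pattern: a semaphore caps the number of in-flight model calls, a lock guards the shared counters, and `asyncio.gather` fans out one coroutine per row. Before the full listing, here is a minimal Flyte-free sketch of that pattern. It assumes the judge is any async callable that returns a raw string verdict, and the `on_progress` callback is a hypothetical stand-in for the live `flyte.report.log.aio` updates.

```python
import asyncio
from typing import Awaitable, Callable, Iterable, Optional


def normalize_verdict(raw: str) -> bool:
    """A row counts as correct only if the judge replied 'true' (case/whitespace-insensitive)."""
    return raw.strip().lower() == "true"


async def evaluate(
    rows: Iterable[tuple[str, str]],
    judge: Callable[[str, str], Awaitable[str]],
    concurrency: int = 10,
    on_progress: Optional[Callable[[float], None]] = None,
) -> float:
    """Score every (question, answer) row concurrently and return the overall accuracy."""
    semaphore = asyncio.Semaphore(concurrency)  # caps in-flight judge calls
    lock = asyncio.Lock()  # guards the shared counters
    counter = {"correct": 0, "processed": 0}

    async def score(question: str, answer: str) -> None:
        async with semaphore:
            verdict = await judge(question, answer)
        async with lock:
            counter["processed"] += 1
            counter["correct"] += int(normalize_verdict(verdict))
            running = counter["correct"] / counter["processed"]
        if on_progress is not None:
            on_progress(running)  # e.g. push a live accuracy update to a report

    await asyncio.gather(*(score(q, a) for q, a in rows))
    processed = counter["processed"]
    return counter["correct"] / processed if processed else 0.0
```

As in the tutorial, any verdict other than "true" (including "not sure") is scored as incorrect, so an ambiguous judge lowers accuracy rather than inflating it.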

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas>=2.0.0",
#    "sqlalchemy>=2.0.0",
#    "llama-index-core>=0.11.0",
#    "llama-index-llms-openai>=0.2.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from data_ingestion import TableInfo
from flyte.io import Dir, File
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import SQLRetriever
from sqlalchemy import create_engine
from text_to_sql import data_ingestion, generate_sql, index_all_tables, retrieve_tables
from utils import env

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a CSV (a Flyte File, a local path, or a URL such as a
    public Google Sheet CSV export) and split it 50/50 into val/test DataFrames.
    The CSV must have 'input', 'target', and 'sql' columns.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    # 'sql' holds the ground-truth query and is read later via row.sql
    required = {"input", "target", "sql"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"CSV is missing required columns: {sorted(missing)}")

    # Shuffle rows with a fixed seed so the split is reproducible
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Rename to the names used downstream, then split 50/50 into val/test
    df_renamed = df.rename(columns={"input": "question", "target": "answer"})

    split = len(df_renamed) // 2
    df_val = df_renamed.iloc[:split]
    df_test = df_renamed.iloc[split:]

    return df_val, df_test

@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
    await db_file.download(local_path="local_db.sqlite")

    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)

    retrieved_rows = sql_retriever.retrieve(sql)

    if retrieved_rows:
        # Get the structured result and stringify
        return str(retrieved_rows[0].node.metadata["result"])

    return ""

async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    db_file: File,
    table_infos: list[TableInfo | None],
    vector_index_dir: Dir,
) -> dict:
    # Generate response from target model
    table_context = await retrieve_tables(
        question, table_infos, db_file, vector_index_dir
    )
    sql = await generate_sql(
        question,
        table_context,
        target_model_config.model_name,
        target_model_config.prompt,
    )
    sql = sql.replace("sql\n", "")

    try:
        response = await generate_response(db_file, sql)
    except Exception as e:
        print(f"Failed to generate response for question {question}: {e}")
        response = None

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                query_str=question,
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "sql": sql,
        "is_correct": verdict_clean == "true",
    }

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    sql,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
    db_file,
    table_infos,
    vector_index_dir,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
                db_file,
                table_infos,
                vector_index_dir,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(str(question))}</td>
                    <td>{html.escape(str(answer))}</td>
                    <td>{html.escape(str(sql))}</td>
                    <td>{html.escape(str(result['model_response']))}</td>
                    <td>{html.escape(result['sql'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

@dataclass
class DatabaseConfig:
    csv_zip_path: str
    search_glob: str
    concurrency: int
    model: str

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
    db_config: DatabaseConfig,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Ground Truth Answer</th>
                    <th>Ground Truth SQL</th>
                    <th>Model Response</th>
                    <th>Model SQL</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    db_file, table_infos = await data_ingestion(
        db_config.csv_zip_path,
        db_config.search_glob,
        db_config.concurrency,
        db_config.model,
    )

    vector_index_dir = await index_all_tables(db_file)

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            row.sql,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
            db_file,
            table_infos,
            vector_index_dir,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_val: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
    db_config: DatabaseConfig,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_val,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Propose and evaluate up to max_iterations new prompts, streaming after each.
    # A fixed budget guarantees termination even if the optimizer never returns a [[...]] prompt.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_val,
                target_model_config,
                review_model_config,
                concurrency,
                db_config,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:+.2f} percentage points</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    ground_truth_csv: File | str = "/root/ground_truth.csv",
    db_config: DatabaseConfig = DatabaseConfig(
        csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
        search_glob="WikiTableQuestions/csv/200-csv/*.csv",
        concurrency=5,
        model="gpt-4o-mini",
    ),
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""Given an input question, create a syntactically correct {dialect} query to run.

Schema:
{schema}

Question: {query_str}

SQL query to run:
""",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.

- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".

Question:
{query_str}

Ground Truth:
{answer}

Model Response:
{response}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.

<EXAMPLE>
<SCHEMA>
artists(id, name)
albums(id, title, artist_id, release_year)
</SCHEMA>
<QUESTION>
How many albums did The Beatles release?
</QUESTION>
<ANSWER>
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 5,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
        ground_truth_csv = await File.from_local(ground_truth_csv)

    df_val, df_test = await data_prep(ground_truth_csv)

    best_prompt, val_accuracy = await prompt_optimizer(
        df_val,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
        db_config,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

    return {
        "best_prompt": best_prompt,
        "validation_accuracy": val_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py*

Here's how prompt accuracy evolves over time, as shown in the UI report:

![Prompt accuracies](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/text-to-sql/prompt_accuracies.png)

### Iterative optimization

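An optimizer LLM proposes new prompts by analyzing which earlier prompts scored well and which scored poorly: its input lists every prompt tried so far with its validation accuracy, in ascending order. Each candidate runs through the same evaluation loop, and we keep the prompt with the highest validation accuracy.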

```
# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_val: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
    db_config: DatabaseConfig,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_val,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each
    while len(prompt_accuracies) <= max_iterations:
        with flyte.group(name=f"prompt_optimization_step_{len(prompt_accuracies)}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_val,
                target_model_config,
                review_model_config,
                concurrency,
                db_config,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py*

On paper, this creates a continuous improvement cycle: baseline → new variants → measured gains.
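Stripped of Flyte tasks and HTML reporting, the loop in `prompt_optimizer` reduces to the plain-Python sketch below. It is an illustration, not the tutorial's code: `evaluate` and `propose` are hypothetical stand-ins for `evaluate_prompt` and the optimizer's `call_model` request, and `max_attempts` is an added guard. In the tutorial code, a response without a `[[...]]` block hits `continue` without growing `prompt_accuracies`, so an optimizer that never follows the output format would keep the loop running.

```python
import re
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


# Same pattern as prompt_optimizer: lazily capture the first [[...]] block, across newlines.
PROMPT_PATTERN = re.compile(r"\[\[(.*?)\]\]", re.DOTALL)


def format_history(results: list[PromptResult]) -> str:
    """Render past prompts worst-to-best, the ascending order the optimizer prompt describes."""
    return "\n".join(
        f"{r.prompt}: {r.accuracy:.2f}"
        for r in sorted(results, key=lambda r: r.accuracy)
    )


def extract_prompt(response: str) -> Optional[str]:
    """Return the text inside the first [[...]] block, or None if there is none."""
    match = PROMPT_PATTERN.search(response.strip())
    return match.group(1) if match else None


def optimize(
    baseline_prompt: str,
    evaluate: Callable[[str], float],  # hypothetical stand-in for evaluate_prompt
    propose: Callable[[str], str],  # hypothetical stand-in for the optimizer LLM call
    max_iterations: int,
    max_attempts: int,  # added guard, not present in the tutorial code
) -> PromptResult:
    history = [PromptResult(baseline_prompt, evaluate(baseline_prompt))]
    attempts = 0
    # Accept up to max_iterations new prompts, but never call the optimizer more than max_attempts times.
    while len(history) <= max_iterations and attempts < max_attempts:
        attempts += 1
        candidate = extract_prompt(propose(format_history(history)))
        if candidate is None:
            continue  # malformed response: skipped, but it still counts as an attempt
        history.append(PromptResult(candidate, evaluate(candidate)))
    # The baseline stays in history, so the winner is never worse on validation than the start.
    return max(history, key=lambda r: r.accuracy)


# Usage with canned scores and responses instead of real model calls.
scores = {"baseline": 0.50, "better": 0.80, "worse": 0.30}
responses = iter(["no brackets here", "analysis... [[better]]", "[[worse]]"])
best = optimize(
    "baseline",
    evaluate=lambda p: scores[p],
    propose=lambda history: next(responses),
    max_iterations=2,
    max_attempts=5,
)
print(best)  # PromptResult(prompt='better', accuracy=0.8)
```

Because the winner is chosen on validation data, it can still fail to generalize; that is what the separate test-set evaluation in `auto_prompt_engineering` checks.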

## Run it

To create the QA dataset:

```
python create_qa_dataset.py
```

To run the prompt optimization loop:

```
python optimizer.py
```

## What we observed

Prompt optimization didn't consistently lift SQL accuracy in this workflow. Accuracy plateaued near the baseline. But the process surfaced valuable lessons about what matters when building LLM-powered systems on real infrastructure.

- **Schema clarity matters**: CSV ingestion produced tables with overlapping names, creating ambiguity. This showed how schema design and metadata hygiene directly affect downstream evaluation.
- **Ground truth needs trust**: Because the dataset came from LLM outputs, noise remained even after filtering. Human review proved essential. Golden datasets need deliberate curation, not just automation.
- **Optimization needs context**: The optimizer couldn't “see” which examples failed, limiting its ability to improve. Feeding failures directly risks overfitting. A structured way to capture and reuse evaluation signals is the right long-term path.

Sometimes prompt tweaks alone can lift accuracy, but other times the real bottleneck lives in the data, the schema, or the evaluation loop. The lesson isn't "prompt optimization doesn't work", but that its impact depends on the system around it. Accuracy improves most reliably when prompts evolve alongside clean data, trusted evaluation, and observable feedback loops.
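The "Optimization needs context" observation points toward capturing evaluation signals in a structured form. As a hypothetical illustration (none of these names come from the tutorial code), each row's outcome can be kept as a typed record instead of only being rendered into the HTML report. A capped sample of failing questions could then be shown to an optimizer without handing it the whole evaluation set; the cap is one assumed way to limit the overfitting risk noted above, not a tested remedy.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EvalRecord:
    """One evaluated row; mirrors the index and is_correct keys returned by generate_and_review."""
    index: int
    question: str
    is_correct: bool


def summarize(records: list[EvalRecord], max_examples: int = 3) -> dict:
    """Aggregate accuracy and keep a small, capped sample of failing questions."""
    failures = [r for r in records if not r.is_correct]
    accuracy = (len(records) - len(failures)) / len(records) if records else 0.0
    return {
        "accuracy": accuracy,
        "failed_indices": [r.index for r in failures],
        # Only a few examples, so an optimizer sees failure patterns rather than the full eval set.
        "sample_failures": [r.question for r in failures[:max_examples]],
    }


records = [
    EvalRecord(0, "How many albums did The Beatles release?", True),
    EvalRecord(1, "Which team scored the most goals?", False),
    EvalRecord(2, "When was the stadium built?", False),
    EvalRecord(3, "Who won in 2004?", True),
]
summary = summarize(records, max_examples=1)
print(summary)
```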

## The bigger lesson

Evaluation and optimization aren’t one-off experiments; they’re continuous processes. What makes them sustainable isn't a clever prompt; it’s the platform around it.

Systems succeed when they:

- **Observe** failures with clarity — track exactly what failed and why.
- **Remain durable** across iterations — run pipelines that are stable, reproducible, and comparable over time.

That's where Flyte 2 comes in. Prompt optimization is one lever, but it becomes powerful only when combined with:

- Clean, human-validated evaluation datasets.
- Systematic reporting and feedback loops.

**The real takeaway: improving LLM pipelines isn't about chasing the perfect prompt. It's about designing workflows with observability and durability at the core, so that every experiment compounds into long-term progress.**

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/auto_prompt_engineering ===

# Automatic prompt engineering

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/auto_prompt_engineering).

When building with LLMs and agents, the first prompt almost never works. We usually need several iterations before results are useful. Doing this manually is slow, inconsistent, and hard to reproduce.

Flyte turns prompt engineering into a systematic process. With Flyte we can:

- Generate candidate prompts automatically.
- Run evaluations in parallel.
- Track results in real time with built-in observability.
- Recover from failures without losing progress.
- Trace the lineage of every experiment for reproducibility.
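Running evaluations in parallel comes down to a bounded fan-out: start one coroutine per row, but let only a fixed number of them call models at once. The asyncio-only sketch below shows that pattern, the same `asyncio.Semaphore`, `asyncio.gather`, and lock-protected counter combination that `run_grouped_task` uses later in this tutorial. `evaluate_row` is a hypothetical stand-in for `generate_and_review`, and the in-flight tracking exists only to make the concurrency limit observable.

```python
import asyncio


async def evaluate_row(row_id: int) -> bool:
    """Hypothetical stand-in for generate_and_review: simulate model latency, return a verdict."""
    await asyncio.sleep(0.01)
    return row_id % 2 == 0  # pretend even-numbered rows are judged correct


async def evaluate_all(row_ids: list[int], concurrency: int) -> tuple[float, int]:
    semaphore = asyncio.Semaphore(concurrency)  # caps how many rows are evaluated at once
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()
    in_flight = 0
    peak = 0

    async def run_one(row_id: int) -> None:
        nonlocal in_flight, peak
        async with semaphore:
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                is_correct = await evaluate_row(row_id)
            finally:
                in_flight -= 1
        async with counter_lock:
            counter["processed"] += 1
            counter["correct"] += int(is_correct)

    # All coroutines are created up front; each waits at the semaphore until a slot frees up.
    await asyncio.gather(*(run_one(r) for r in row_ids))
    accuracy = counter["correct"] / counter["processed"] if counter["processed"] else 0.0
    return accuracy, peak


accuracy, peak = asyncio.run(evaluate_all(list(range(10)), concurrency=3))
print(f"accuracy={accuracy:.2f}, peak concurrency={peak}")
```

Raising `concurrency` shortens wall-clock time until provider rate limits become the bottleneck, which is why the tutorial exposes it as a task parameter.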

And we're not limited to prompts. Just like [hyperparameter optimization](../hpo/_index) in ML, we can tune model temperature, retrieval strategies, tool usage, and more. Over time, this grows into full agentic evaluations, tracking not only prompts but also how agents behave, make decisions, and interact with their environment.
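To illustrate tuning more than the prompt text, the sketch below expands a base model configuration into a grid of candidates over prompts and temperatures. It assumes a dataclass with the same fields as the `ModelConfig` defined later in this tutorial; `candidate_configs` and the example prompts are hypothetical. `dataclasses.replace` returns new instances, so the base configuration is never mutated and each candidate stays an independent input.

```python
from dataclasses import dataclass, replace
from itertools import product
from typing import Optional


@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""


def candidate_configs(
    base: ModelConfig, prompts: list[str], temperatures: list[float]
) -> list[ModelConfig]:
    """Cross every prompt with every temperature; all other fields are copied from base."""
    return [
        replace(base, prompt=p, temperature=t)
        for p, t in product(prompts, temperatures)
    ]


base = ModelConfig(model_name="gpt-4.1-mini")
grid = candidate_configs(base, ["Answer concisely.", "Think step by step."], [0.0, 0.7])
for config in grid:
    print(config.temperature, config.prompt)
```

Each candidate could then be scored by the same evaluation task, in the same way the tutorial scores prompts.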

In this tutorial, we'll build an automated prompt engineering pipeline with Flyte, step by step.

## Set up the environment

First, let's configure our task environment.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

# {{/docs-fragment env}}

# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test

# {{/docs-fragment data_prep}}

# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

# {{/docs-fragment model_config}}

# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

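    # Step 2: Optimize prompts one by one, streaming after each.
    # A bounded loop guarantees termination even if the optimizer never
    # returns a prompt wrapped in [[...]]; failed attempts are skipped.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):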
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information, like prompt length, formal/informal use of language, etc., for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*

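We need an OpenAI API key to call the models in this tutorial: GPT-4.1 as the optimizer model and GPT-4.1 mini as the target and review models. Add it as a Flyte secret; the task environment exposes it to LiteLLM as the `OPENAI_API_KEY` environment variable: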

```
flyte create secret openai_api_key <YOUR_OPENAI_API_KEY>
```

The `CSS` string styles the live HTML reports (tasks declared with `report=True`) that track evaluation and prompt optimization in real time:

![Results](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/gifs/tutorials/prompt_engineering/results.gif)
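To make the streaming mechanism concrete, here is a minimal sketch that rebuilds the HTML strings `run_grouped_task` sends to the report: one `updateAccuracy(...)` script call per processed row (the in-page JavaScript sets the bar width to `percent * 3` inside a 300 px track) plus one escaped table row. The helper names `accuracy_updates` and `table_row` are hypothetical, and the sketch does not call Flyte.

```python
import html

# Assumption: these helpers only rebuild the HTML strings that
# run_grouped_task sends via flyte.report.log.aio; they do not call Flyte.


def accuracy_updates(verdicts: list[bool]) -> list[str]:
    """Return one updateAccuracy(...) snippet per processed row, in order."""
    correct = 0
    snippets = []
    for processed, is_correct in enumerate(verdicts, start=1):
        correct += int(is_correct)
        accuracy_pct = (correct / processed) * 100  # running accuracy so far
        snippets.append(f"<script>updateAccuracy({accuracy_pct});</script>")
    return snippets


def table_row(question: str, answer: str, response: str, is_correct: bool) -> str:
    """Build one results-table row, escaping all model- and data-supplied text."""
    mark = (
        "<span class='correct'>✔ Yes</span>"
        if is_correct
        else "<span class='incorrect'>✘ No</span>"
    )
    cells = [html.escape(question), html.escape(answer), html.escape(response), mark]
    return "<tr>" + "".join(f"<td>{cell}</td>" for cell in cells) + "</tr>"


for snippet in accuracy_updates([True, False, True, True]):
    print(snippet)
print(table_row('<path d="M 0,0"/> draws a', "(B)", "The answer is (B)", True))
```

Escaping matters here because the geometric shapes questions contain raw SVG markup, which the browser would otherwise try to render inside the table.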

## Prepare the evaluation dataset

Next, we define our golden dataset: a set of inputs with known correct outputs, used to score each candidate prompt.

For this tutorial, we use a small geometric shapes dataset. To keep it portable, the data prep task takes a CSV file (either a Flyte `File` or a URL string pointing to a remote file), shuffles the rows with a fixed seed (`random_state=1234`), and uses the first 150 rows for training and the next 100 for testing.
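The split itself is a shuffle followed by two slices. The sketch below reproduces its shape with the standard library; `split_rows` is a hypothetical helper, and `random.Random` is a stand-in for pandas' `DataFrame.sample(frac=1, random_state=1234)`, so the row order will not match `data_prep` exactly.

```python
import random

# Sketch of the shuffle-then-slice split. random.Random is a stand-in for
# pandas' DataFrame.sample(frac=1, random_state=1234), so the resulting row
# order will differ from data_prep; only the shape of the split matches.


def split_rows(rows, train_size=150, test_size=100, seed=1234):
    """Shuffle deterministically, then take train and test slices."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    train = shuffled[:train_size]
    test = shuffled[train_size : train_size + test_size]
    return train, test


rows = [{"input": f"question {i}", "target": f"answer {i}"} for i in range(300)]
train, test = split_rows(rows)
print(len(train), len(test))  # 150 100 (rows beyond the first 250 are ignored)

small_train, small_test = split_rows(rows[:200])
print(len(small_train), len(small_test))  # 150 50
```

Because the slice bounds are fixed, a dataset with fewer than 250 rows yields a smaller test set, and one with 150 rows or fewer leaves the test set empty.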

If your inputs and expected outputs already live in Google Sheets, export them as CSV with two columns: `input` and `target`.
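As an illustration of the expected file format, the sketch below writes two hypothetical rows to CSV and applies the same column check that `data_prep` performs. The example questions and the `to_csv` / `check_columns` helpers are illustrative only; extra columns are allowed, since only the presence of `input` and `target` is checked.

```python
import csv
import io

# Hypothetical example rows; in practice these come from your own sheet export.
rows = [
    {"input": "Which shape has exactly seven sides? Options: (A) hexagon (B) heptagon", "target": "(B)"},
    {"input": "Which shape has four equal sides and four right angles? Options: (A) square (B) kite", "target": "(A)"},
]


def to_csv(rows: list[dict[str, str]]) -> str:
    """Serialize rows as CSV text with the two columns data_prep expects."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=["input", "target"])
    writer.writeheader()
    writer.writerows(rows)
    return buffer.getvalue()


def check_columns(csv_text: str) -> None:
    """Mirror data_prep's validation: both 'input' and 'target' must be present."""
    header = next(csv.reader(io.StringIO(csv_text)))
    if "input" not in header or "target" not in header:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")


csv_text = to_csv(rows)
check_columns(csv_text)  # passes silently
print(csv_text.splitlines()[0])  # input,target
```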

```
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a CSV file (a Flyte File, or a URL such as a public
    Google Sheet CSV export) and split it into train/test DataFrames.
    The CSV must have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*

This approach works with any dataset that has `input` and `target` columns. Pass a URL or a local file path as `csv_file` to swap in your own data; no extra dependencies are needed.
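`data_prep` shuffles the rows with a fixed seed, keeps the first 150 for training and the next 100 for testing, and renames `input`/`target` to `question`/`answer`. The sketch below mirrors that logic with only the standard library, which can be handy for checking a dataset before handing it to the workflow. `split_dataset` is a hypothetical helper, and its shuffle order differs from pandas' `sample(random_state=1234)`, so it will not pick exactly the same rows as the workflow.

```python
import csv
import io
import random


def split_dataset(csv_text: str, n_train: int = 150, n_test: int = 100, seed: int = 1234):
    """Hypothetical helper mirroring data_prep: validate, shuffle, then slice."""
    reader = csv.DictReader(io.StringIO(csv_text))
    fields = reader.fieldnames or []
    if "input" not in fields or "target" not in fields:
        raise ValueError("CSV must contain 'input' and 'target' columns.")

    # Rename columns the same way data_prep does: input -> question, target -> answer
    rows = [{"question": r["input"], "answer": r["target"]} for r in reader]

    # Deterministic shuffle (the order differs from pandas' sample(random_state=...))
    random.Random(seed).shuffle(rows)

    train = rows[:n_train]
    test = rows[n_train : n_train + n_test]
    return train, test


if __name__ == "__main__":
    sample = "input,target\n" + "\n".join(f"question {i},({i % 3})" for i in range(200))
    train, test = split_dataset(sample)
    print(len(train), len(test))  # 150 50
```

With fewer than 250 rows, the test set simply shrinks, as the 200-row sample above shows.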

## Define models

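We use three models:

- **Target model** → the one whose system prompt we want to optimize.
- **Review model** → the one that judges whether each target-model response matches the ground-truth answer.
- **Optimizer model** → the one that proposes new candidate prompts based on how earlier prompts scored.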
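The target and review roles fit together per row: the target model answers the question under the candidate system prompt, the review model grades that answer against the ground truth, and anything other than a clean `True` verdict counts as incorrect. The sketch below replays that flow, including the semaphore that bounds concurrency, with stub functions in place of LLM calls. `fake_target`, `fake_review`, `review_row`, and `accuracy` are hypothetical stand-ins, not part of the tutorial's code.

```python
import asyncio


async def fake_target(system_prompt: str, question: str) -> str:
    # Stand-in for the target model: always answers with option (B)
    return "Counting the vertices gives 7 sides, so the answer is (B)"


async def fake_review(response: str, answer: str) -> str:
    # Stand-in for the review model: checks whether the ground truth appears in the response
    return "True" if answer in response else "False"


async def review_row(question: str, answer: str, prompt: str) -> bool:
    response = await fake_target(prompt, question)
    verdict = (await fake_review(response, answer)).strip().lower()
    # Only an exact "true" counts as correct; any other verdict is treated as wrong
    return verdict == "true"


async def accuracy(rows: list[tuple[str, str]], prompt: str, concurrency: int = 10) -> float:
    semaphore = asyncio.Semaphore(concurrency)  # bound the number of concurrent model calls

    async def run(q: str, a: str) -> bool:
        async with semaphore:
            return await review_row(q, a, prompt)

    results = await asyncio.gather(*(run(q, a) for q, a in rows))
    return sum(results) / len(results) if results else 0.0


if __name__ == "__main__":
    rows = [("shape 1", "(B)"), ("shape 2", "(J)")]
    print(asyncio.run(accuracy(rows, "Think step by step.")))  # 0.5
```

The accuracy of a candidate prompt is just the fraction of rows the review model marks as correct, which is the number the optimizer sees for each prompt.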

First, we capture all model parameters in a dataclass:

```
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*
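Because each candidate prompt only changes the `prompt` field, one way to derive a new configuration is `dataclasses.replace`, which copies the dataclass and overrides only the fields you name. The tutorial itself assigns `target_model_config.prompt = new_prompt` in place; the sketch below is an alternative that leaves the baseline config untouched. The class is repeated so the snippet runs on its own, and the candidate prompt text is made up for illustration.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""


baseline = ModelConfig(
    model_name="gpt-4.1-mini",
    prompt="Solve the given problem about geometric shapes. Think step by step.",
    max_tokens=10000,
)

# Derive a candidate config without mutating the baseline
candidate = replace(baseline, prompt="Trace each SVG path segment and count the distinct vertices.")

print(baseline.prompt)        # the original prompt is unchanged
print(candidate.max_tokens)   # 10000, copied from the baseline
```

Keeping the baseline immutable makes it easy to re-evaluate the original prompt later, for example on the test split.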

Then we define a Flyte `trace` to call the model. Unlike a task, a trace runs within the same runtime as its parent task rather than in a separate container. Since the model is hosted externally, this keeps each call lightweight while still making it observable.

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

# {{/docs-fragment env}}

# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test

# {{/docs-fragment data_prep}}

# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

# {{/docs-fragment model_config}}

# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'] or '')}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each.
    # A fixed number of attempts prevents an endless loop when the optimizer's reply has no [[...]] block.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc., for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*
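The two string-handling steps inside `prompt_optimizer` (building the sorted `prompt_scores_str` and pulling the new prompt out of `[[...]]`) can be sketched in isolation. This is a minimal, dependency-free sketch; the helper names `build_prompt_scores` and `extract_new_prompt` are hypothetical, and unlike the task above, the extractor also trims surrounding whitespace.

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


def build_prompt_scores(results: list[PromptResult]) -> str:
    # Lowest accuracy first, matching the "ascending order" the optimizer prompt describes
    return "\n".join(
        f"{result.prompt}: {result.accuracy:.2f}"
        for result in sorted(results, key=lambda r: r.accuracy)
    )


def extract_new_prompt(response: str) -> Optional[str]:
    # Non-greedy match captures only the first [[...]] block; DOTALL lets it span lines
    match = re.search(r"\[\[(.*?)\]\]", response.strip(), re.DOTALL)
    # Unlike the tutorial, this sketch also trims surrounding whitespace
    return match.group(1).strip() if match else None


history = [
    PromptResult("Count the vertices first.", 0.71),
    PromptResult("Think step by step.", 0.62),
]
print(build_prompt_scores(history))
# Think step by step.: 0.62
# Count the vertices first.: 0.71
print(extract_new_prompt("What worked: counting. [[List each vertex, then name the shape.]]"))
# List each vertex, then name the shape.
```

Note that in `prompt_optimizer`, a response without a bracketed prompt hits `continue` without adding to `prompt_accuracies`, so the loop only ends once enough prompts have been extracted. If the optimizer model repeatedly ignores the bracket rule, a retry limit may be worth adding.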

You can also host your own models on Union. For example, here we deploy `gpt-oss-20b` with vLLM:

```
import union
from union.app.llm import VLLMApp
from flytekit.extras.accelerators import A10G

Model = union.Artifact(name="gpt-oss-20b")

image = union.ImageSpec(
    name="vllm-gpt-oss",
    builder="union",
    apt_packages=["build-essential", "wget", "gnupg"],
    packages=[
        "union[vllm]==0.1.191b0",
        "--pre vllm==0.10.1+gptoss \
        --extra-index-url https://wheels.vllm.ai/gpt-oss/ \
        --extra-index-url https://download.pytorch.org/whl/nightly/cu128 \
        --index-strategy unsafe-best-match",
    ],
).with_commands(
    [
        "wget https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb",
        "dpkg -i cuda-keyring_1.1-1_all.deb",
        "apt-get update",
        "apt-get install -y cuda-toolkit-12-8",
        "/usr/local/cuda/bin/nvcc --version",
        "chown -R union /root",
        "chown -R union /home",
    ]
)

gpt_oss_app = VLLMApp(
    name="gpt-oss-20b-vllm",
    model=Model.query(),
    model_id="gpt-oss",
    container_image=image,
    requests=union.Resources(cpu="5", mem="26Gi", gpu="1", ephemeral_storage="150Gi"),
    accelerator=A10G,
    scaledown_after=300,
    stream_model=True,
    requires_auth=False,
    extra_args="--async-scheduling",
    env={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/gpt_oss.py*

We use an `A10G` GPU instance. With `stream_model=True`, the model weights are streamed directly into GPU memory instead of being downloaded to disk first and then loaded onto the GPU.

To deploy the model, first cache it from Hugging Face as a Union artifact:

```
union cache model-from-hf \
    --hf-token-key hf-api-key \
    --artifact-name gpt-oss-20b \
    --cpu 2 \
    --mem 8Gi \
    --ephemeral-storage 100Gi openai/gpt-oss-20b
```

Then deploy it:

```
union deploy apps gpt_oss.py gpt-oss-20b-vllm
```

When using a hosted model, set its endpoint as `hosted_model_uri` in `ModelConfig`. Inference then runs on your own deployment, so your data never leaves your environment.
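As a sketch of what that configuration might look like, the snippet below reuses the tutorial's `ModelConfig` dataclass. Several details are assumptions: the endpoint URL is a hypothetical placeholder for whatever URL Union reports for your deployed app, the `hosted_vllm/` model prefix assumes LiteLLM's vLLM provider routing to an OpenAI-compatible `api_base`, and the served model name `gpt-oss` is assumed to match `model_id` in the `VLLMApp` definition.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfig:
    # Same fields as the ModelConfig dataclass in optimizer.py
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""


# Hypothetical endpoint: replace with the URL of your deployed gpt-oss-20b-vllm app
GPT_OSS_ENDPOINT = "https://gpt-oss-20b-vllm.example.com/v1"

target_model_config = ModelConfig(
    # Assumption: "hosted_vllm/" routes the LiteLLM call to api_base;
    # "gpt-oss" is assumed to be the served model name (model_id in VLLMApp)
    model_name="hosted_vllm/gpt-oss",
    hosted_model_uri=GPT_OSS_ENDPOINT,  # passed to acompletion as api_base
    prompt="Solve the given problem about geometric shapes. Think step by step.",
    max_tokens=10000,
)
```

Because `call_model` forwards `hosted_model_uri` as `api_base`, no other code changes should be needed; for OpenAI-hosted models, leaving it as `None` keeps the default endpoint.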

Finally, a helper function uses the traced `call_model` to query both the target model and the review model:

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io._file import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

# {{/docs-fragment env}}

# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test

# {{/docs-fragment data_prep}}

# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

# {{/docs-fragment model_config}}

# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each
    while len(prompt_accuracies) <= max_iterations:
        with flyte.group(name=f"prompt_optimization_step_{len(prompt_accuracies)}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc., for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*
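To try the generate-then-review flow without API keys, here is a minimal sketch that swaps the LiteLLM call for a hypothetical `stub_call_model`. The stub is not a real model: in the review role it simply checks whether the ground truth string appears in the response, and `REVIEW_TEMPLATE` is a shortened stand-in for the tutorial's review prompt.

```python
import asyncio

# Shortened stand-in for the review prompt; uses the same {response}/{answer} placeholders
REVIEW_TEMPLATE = "Model Response:\n{response}\n\nGround Truth:\n{answer}\nRespond with True or False."


async def stub_call_model(messages: list[dict[str, str]]) -> str:
    # Hypothetical stand-in for the traced call_model: no network access
    system = messages[0]["content"]
    if system.startswith("Model Response:"):
        # Review role: "True" if the ground truth appears in the model response
        response_part, truth_part = system.split("\n\nGround Truth:\n")
        truth = truth_part.split("\n")[0]
        return "True" if truth in response_part else "False"
    # Target role: always answers option (B)
    return "The path has seven vertices, so the answer is (B)."


async def generate_and_review(question: str, answer: str, target_prompt: str) -> bool:
    # 1. Target model answers with the candidate prompt as its system message
    response = await stub_call_model(
        [
            {"role": "system", "content": target_prompt},
            {"role": "user", "content": question},
        ]
    )
    # 2. Review model compares the response against the ground truth
    verdict = await stub_call_model(
        [{"role": "system", "content": REVIEW_TEMPLATE.format(response=response, answer=answer)}]
    )
    # Same normalization idea as the tutorial: anything other than "true" counts as incorrect
    return verdict.strip().lower() == "true"


print(asyncio.run(generate_and_review("This SVG path element ...", "(B)", "Think step by step.")))
```

Separating the generator from the reviewer keeps grading independent of the prompt being optimized: only the target model's system prompt changes between iterations.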

## Evaluate prompts

We now define the evaluation process.

Each question in the dataset is evaluated concurrently, with an `asyncio.Semaphore` limiting how many run at the same time. A helper function, `run_grouped_task`, ties `generate_and_review` together with the HTML report template, and `asyncio.gather` runs that helper for every row at once.

The task measures accuracy as the fraction of responses that the review model judges correct against the ground truth. Flyte streams these results to the UI, so you can watch the evaluation happen live.
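The semaphore-plus-`gather` pattern can be sketched on its own. In this minimal example, a short random sleep stands in for `generate_and_review`, each row carries a precomputed correctness flag instead of a real review verdict, and the hypothetical `peak` counter records how many rows were in flight at once so you can confirm the semaphore caps concurrency.

```python
import asyncio
import random


async def evaluate(rows: list[tuple[str, bool]], concurrency: int) -> tuple[float, int]:
    """Score rows concurrently; return (accuracy, peak number of rows in flight)."""
    semaphore = asyncio.Semaphore(concurrency)
    lock = asyncio.Lock()  # mirrors the tutorial's counter_lock
    counter = {"correct": 0, "processed": 0, "in_flight": 0, "peak": 0}

    async def score(question: str, is_correct: bool) -> None:
        async with semaphore:  # at most `concurrency` rows pass this point at once
            async with lock:
                counter["in_flight"] += 1
                counter["peak"] = max(counter["peak"], counter["in_flight"])
            # Stand-in for generate_and_review: simulate model latency
            await asyncio.sleep(random.uniform(0, 0.01))
            async with lock:
                counter["in_flight"] -= 1
                counter["processed"] += 1
                counter["correct"] += int(is_correct)

    # Schedule every row immediately; the semaphore throttles actual execution
    await asyncio.gather(*(score(q, ok) for q, ok in rows))
    accuracy = counter["correct"] / counter["processed"] if counter["processed"] else 0.0
    return accuracy, counter["peak"]


rows = [(f"question-{i}", i % 4 != 0) for i in range(20)]  # 15 of 20 marked correct
print(asyncio.run(evaluate(rows, concurrency=5)))
```

Creating all coroutines up front and gating them with a semaphore keeps the code simple while still bounding the number of simultaneous model calls, which helps stay under provider rate limits.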

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io._file import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

# {{/docs-fragment env}}

# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test

# {{/docs-fragment data_prep}}

# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

# {{/docs-fragment model_config}}

# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each
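    # Bounded loop: a reply without a [[...]] prompt still uses up an attempt,
    # so the optimizer cannot keep the loop running indefinitely.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):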
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc., for your analysis.
- Be creative and try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead, focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*

## Optimize prompts

Optimization builds on evaluation. We give the optimizer model:

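- every prompt tested so far, sorted from lowest to highest accuracy, and
- the training-set accuracy each of those prompts achieved.

The model then reasons about which patterns helped and which hurt. It proposes a new prompt wrapped in double square brackets (`[[...]]`), which the task extracts with a regular expression.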
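Here is a minimal sketch in plain Python of the two string-handling steps around the optimizer call. The first renders the history into the text that fills the `{prompt_scores_str}` placeholder. The second pulls the proposed prompt out of the reply. It follows the tutorial's ascending sort and `[[...]]` convention. `format_history` and `extract_prompt` are hypothetical helper names, and the sample reply is invented for illustration.

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


def format_history(results: list[PromptResult]) -> str:
    """One 'prompt: accuracy' line per result, lowest accuracy first."""
    ordered = sorted(results, key=lambda r: r.accuracy)
    return "\n".join(f"{r.prompt}: {r.accuracy:.2f}" for r in ordered)


# Non-greedy and DOTALL: only the first [[...]] block is taken,
# and it may span several lines.
PROMPT_PATTERN = re.compile(r"\[\[(.*?)\]\]", re.DOTALL)


def extract_prompt(reply: str) -> Optional[str]:
    """Return the text inside the first [[...]] block, or None if there is none."""
    match = PROMPT_PATTERN.search(reply)
    # Stripping surrounding whitespace is an extra step the tutorial does not take.
    return match.group(1).strip() if match else None


history = [
    PromptResult("Think step by step.", 0.62),
    PromptResult("Solve the problem.", 0.48),
]
print(format_history(history))

reply = "Step-by-step prompts scored higher.\n[[Trace each SVG segment, count the vertices, then pick an option.]]"
print(extract_prompt(reply))
```

Returning `None` instead of raising lets the caller decide whether to skip the attempt, which is what the tutorial does when no match is found.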

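We start with a _baseline_ evaluation of the user-provided prompt on the training set. Then, on each iteration, the optimizer proposes a new prompt, which we evaluate on the same training data and log as a new row in the report. If a reply contains no bracketed prompt, that iteration is skipped. We continue until we reach the iteration limit set by `max_iterations`.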
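You can sketch the loop itself without Flyte or an API key by replacing the model calls with plain callables. `optimize`, `evaluate`, and `propose` are hypothetical stand-ins for `prompt_optimizer`, `evaluate_prompt`, and the optimizer call plus bracket extraction. The sketch caps attempts with `range(max_iterations)`, so a reply without a usable prompt consumes an attempt instead of being retried. That cap is a design choice of this sketch.

```python
from typing import Callable, Optional

History = list[tuple[str, float]]


def optimize(
    baseline_prompt: str,
    evaluate: Callable[[str], float],
    propose: Callable[[History], Optional[str]],
    max_iterations: int,
) -> tuple[str, float, float]:
    """Return (best_prompt, best_accuracy, baseline_accuracy)."""
    # Baseline: score the user-provided prompt before asking for new ones.
    baseline_accuracy = evaluate(baseline_prompt)
    history: History = [(baseline_prompt, baseline_accuracy)]

    # At most max_iterations optimizer calls, even if some replies fail to parse.
    for _ in range(max_iterations):
        candidate = propose(history)
        if candidate is None:
            continue  # no usable prompt in the reply: this attempt is skipped
        history.append((candidate, evaluate(candidate)))

    # max() keeps the first entry on ties, so the baseline wins a tie.
    best_prompt, best_accuracy = max(history, key=lambda item: item[1])
    return best_prompt, best_accuracy, baseline_accuracy


# Toy stand-ins with fixed scores; the second proposal "fails to parse".
scores = {"base": 0.5, "base v1": 0.7, "base v1 v2": 0.9}
proposals = iter(["base v1", None, "base v1 v2"])
result = optimize("base", scores.__getitem__, lambda history: next(proposals), 3)
print(result)
```

`max()` returns the first of several equal maxima. As a result, a candidate that only ties the baseline does not replace it.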

```
@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each
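    # Bounded loop: a reply without a [[...]] prompt still uses up an attempt,
    # so the optimizer cannot keep the loop running indefinitely.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):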
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*

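At the end, the task returns the best-scoring prompt and its training accuracy. It also appends a summary card showing the improvement over the baseline in percentage points. The report lists every prompt that was tested together with its accuracy, so you can see how accuracy changed across iterations.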
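The summary numbers are plain arithmetic over the recorded results. The sketch below assumes the baseline is the first entry, as it is in `prompt_optimizer`. It picks the best entry by accuracy and reports the improvement in percentage points. It also mirrors the color cut-offs in `_log_prompt_row`: above 80% green, above 60% amber, otherwise red. `summarize`, `bar_color`, and the color names are hypothetical.

```python
def summarize(results: list[tuple[str, float]]) -> tuple[str, float, float]:
    """results[0] is the baseline; accuracies are fractions between 0 and 1.

    Returns (best_prompt, best_accuracy_pct, improvement_in_percentage_points),
    rounded to two decimals like the report's display.
    """
    baseline_accuracy = results[0][1]
    best_prompt, best_accuracy = max(results, key=lambda r: r[1])
    improvement = best_accuracy - baseline_accuracy
    return best_prompt, round(best_accuracy * 100, 2), round(improvement * 100, 2)


def bar_color(accuracy: float) -> str:
    """Color band for the accuracy bar, using the same cut-offs as _log_prompt_row."""
    pct = accuracy * 100
    if pct > 80:
        return "green"
    if pct > 60:
        return "amber"
    return "red"


runs = [("baseline", 0.55), ("candidate A", 0.84), ("candidate B", 0.70)]
print(summarize(runs))
print([bar_color(acc) for _, acc in runs])
```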

![Report](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/gifs/tutorials/prompt_engineering/prompt_accuracies.png)

## Build the full pipeline

The entrypoint task wires everything together:

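- Accepts the target, review, and optimizer model configs, the dataset (a CSV URL or a local file path), the iteration count, and the concurrency limit.
- Runs data preparation, which shuffles the rows and splits them into 150 training and 100 test examples.
- Calls the optimizer on the training split.
- Evaluates both the baseline prompt and the best prompt on the held-out test set, so the reported gain is not measured on the same data used to select the prompt.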
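Without Flyte, the wiring can be sketched in plain Python. This is an illustration under assumptions, not the tutorial's code. It replaces pandas with the standard library, so `random.Random(1234).shuffle` will not reproduce the exact split that `df.sample(frac=1, random_state=1234)` produces. Only the 150/100 split sizes come from `data_prep`. `split_rows`, `run_pipeline`, `optimize`, and `evaluate` are hypothetical names; the last two stand in for `prompt_optimizer` and `evaluate_prompt`.

```python
import random
from typing import Callable

Row = tuple[str, str]  # (question, answer)


def split_rows(
    rows: list[Row], seed: int = 1234, n_train: int = 150, n_test: int = 100
) -> tuple[list[Row], list[Row]]:
    """Deterministic shuffle, then a train slice followed by a disjoint test slice."""
    shuffled = list(rows)  # copy so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)
    return shuffled[:n_train], shuffled[n_train : n_train + n_test]


def run_pipeline(
    rows: list[Row],
    baseline_prompt: str,
    optimize: Callable[[list[Row], str], tuple[str, float]],
    evaluate: Callable[[list[Row], str], float],
) -> dict[str, object]:
    train, test = split_rows(rows)
    # Prompt search only ever sees the training rows.
    best_prompt, training_accuracy = optimize(train, baseline_prompt)
    # Baseline and best prompt are scored on the same held-out rows.
    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": evaluate(test, baseline_prompt),
        "test_accuracy": evaluate(test, best_prompt),
    }


rows = [(f"q{i}", f"a{i}") for i in range(300)]
summary = run_pipeline(
    rows,
    "Think step by step.",
    optimize=lambda train, prompt: ("Count the vertices first.", 0.81),
    evaluate=lambda test, prompt: 0.78 if prompt.startswith("Count") else 0.64,
)
print(summary)
```

The returned keys match the dictionary that `auto_prompt_engineering` returns. This makes it easy to compare `baseline_test_accuracy` with `test_accuracy` directly.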

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

CSS = """
<style>
    body {
        font-family: 'Segoe UI', Roboto, Arial, sans-serif;
    }
    .results-table {
        border-collapse: collapse;
        width: 100%;
        box-shadow: 0 2px 5px rgba(0,0,0,0.1);
        font-size: 14px;
    }
    .results-table th {
        background: linear-gradient(135deg, #4CAF50, #2E7D32);
        color: white;
        padding: 10px;
        text-align: left;
    }
    .results-table td {
        border: 1px solid #ddd;
        padding: 8px;
        vertical-align: top;
    }
    .results-table tr:nth-child(even) {background-color: #f9f9f9;}
    .results-table tr:hover {background-color: #f1f1f1;}
    .correct {color: #2E7D32; font-weight: bold;}
    .incorrect {color: #C62828; font-weight: bold;}
    .summary-card {
        background: #f9fbfd;
        padding: 14px 18px;
        border-radius: 8px;
        box-shadow: 0 1px 4px rgba(0,0,0,0.05);
        max-width: 800px;
        margin-top: 12px;
    }
    .summary-card h3 {
        margin-top: 0;
        color: #1e88e5;
        font-size: 16px;
    }
</style>
"""

# {{/docs-fragment env}}

# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a CSV (a URL such as a public Google Sheet CSV export, or a File),
    shuffle it, and split it into train/test DataFrames.
    The CSV must have 'input' and 'target' columns.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/Test split
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})

    return df_train, df_test

# {{/docs-fragment data_prep}}

# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""

# {{/docs-fragment model_config}}

# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]

# {{/docs-fragment call_model}}

# {{docs-fragment generate_and_review}}
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate response from target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }

# {{/docs-fragment generate_and_review}}

async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "<span class='correct'>✔ Yes</span>"
                else:
                    correct_html = "<span class='incorrect'>✘ No</span>"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"<script>updateAccuracy({accuracy_pct});</script>",
                do_flush=True,
            )

            # Add row to table
            await flyte.report.log.aio(
                f"""
                <tr>
                    <td>{html.escape(question)}</td>
                    <td>{html.escape(answer)}</td>
                    <td>{html.escape(result['model_response'])}</td>
                    <td>{correct_html}</td>
                </tr>
                """,
                do_flush=True,
            )

            return result

# {{docs-fragment evaluate_prompt}}
@env.task(report=True)
async def evaluate_prompt(
    df: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    concurrency: int,
) -> float:
    semaphore = asyncio.Semaphore(concurrency)
    counter = {"correct": 0, "processed": 0}
    counter_lock = asyncio.Lock()

    # Write initial HTML structure
    await flyte.report.log.aio(
        CSS
        + """
        <script>
            function updateAccuracy(percent) {
                const bar = document.getElementById('acc-bar');
                const label = document.getElementById('acc-label');
                bar.setAttribute('width', percent * 3);
                label.textContent = `Accuracy: ${percent.toFixed(1)}%`;
            }
        </script>

        <h2 style="margin-top:0;">Model Evaluation Results</h2>
        <h3>Live Accuracy</h3>
        <svg width="320" height="30" id="accuracy-chart">
            <defs>
                <linearGradient id="acc-gradient" x1="0" x2="1" y1="0" y2="0">
                    <stop offset="0%" stop-color="#66bb6a"/>
                    <stop offset="100%" stop-color="#2e7d32"/>
                </linearGradient>
            </defs>
            <rect width="300" height="20" fill="#ddd" rx="5" ry="5"></rect>
            <rect id="acc-bar" width="0" height="20" fill="url(#acc-gradient)" rx="5" ry="5"></rect>
            <text id="acc-label" x="150" y="15" font-size="12" font-weight="bold" text-anchor="middle" fill="#000">
                Accuracy: 0.0%
            </text>
        </svg>

        <table class="results-table">
            <thead>
                <tr>
                    <th>Question</th>
                    <th>Answer</th>
                    <th>Model Response</th>
                    <th>Correct?</th>
                </tr>
            </thead>
            <tbody>
        """,
        do_flush=True,
    )

    # Launch tasks concurrently
    tasks = [
        run_grouped_task(
            i,
            row.Index,
            row.question,
            row.answer,
            semaphore,
            target_model_config,
            review_model_config,
            counter,
            counter_lock,
        )
        for i, row in enumerate(df.itertuples(index=True))
    ]
    await asyncio.gather(*tasks)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    async with counter_lock:
        return (
            (counter["correct"] / counter["processed"]) if counter["processed"] else 0.0
        )

# {{/docs-fragment evaluate_prompt}}

@dataclass
class PromptResult:
    prompt: str
    accuracy: float

# {{docs-fragment prompt_optimizer}}
@env.task(report=True)
async def prompt_optimizer(
    df_train: pd.DataFrame,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    optimizer_model_config: ModelConfig,
    max_iterations: int,
    concurrency: int,
) -> tuple[str, float]:
    prompt_accuracies: list[PromptResult] = []

    # Send styling + table header immediately
    await flyte.report.log.aio(
        CSS
        + """
    <h2 style="margin-bottom:6px;">📊 Prompt Accuracy Comparison</h2>
    <table class="results-table">
        <thead>
            <tr>
                <th>Prompt</th>
                <th>Accuracy</th>
            </tr>
        </thead>
    <tbody>
    """,
        do_flush=True,
    )

    # Step 1: Evaluate starting prompt and stream row
    with flyte.group(name="baseline_evaluation"):
        starting_accuracy = await evaluate_prompt(
            df_train,
            target_model_config,
            review_model_config,
            concurrency,
        )
        prompt_accuracies.append(
            PromptResult(prompt=target_model_config.prompt, accuracy=starting_accuracy)
        )

        await _log_prompt_row(target_model_config.prompt, starting_accuracy)

    # Step 2: Optimize prompts one by one, streaming after each
    # A fixed number of steps means a reply without [[...]] skips that step
    # instead of retrying indefinitely.
    for step in range(1, max_iterations + 1):
        with flyte.group(name=f"prompt_optimization_step_{step}"):
            # Prepare prompt scores string for optimizer
            prompt_scores_str = "\n".join(
                f"{result.prompt}: {result.accuracy:.2f}"
                for result in sorted(prompt_accuracies, key=lambda x: x.accuracy)
            )

            optimizer_model_prompt = optimizer_model_config.prompt.format(
                prompt_scores_str=prompt_scores_str
            )
            response = await call_model(
                optimizer_model_config,
                [{"role": "system", "content": optimizer_model_prompt}],
            )
            response = response.strip()

            match = re.search(r"\[\[(.*?)\]\]", response, re.DOTALL)
            if not match:
                print("No new prompt found. Skipping.")
                continue

            new_prompt = match.group(1)
            target_model_config.prompt = new_prompt
            accuracy = await evaluate_prompt(
                df_train,
                target_model_config,
                review_model_config,
                concurrency,
            )
            prompt_accuracies.append(PromptResult(prompt=new_prompt, accuracy=accuracy))

            # Log this new prompt row immediately
            await _log_prompt_row(new_prompt, accuracy)

    # Close table
    await flyte.report.log.aio("</tbody></table>", do_flush=True)

    # Find best
    best_result = max(prompt_accuracies, key=lambda x: x.accuracy)
    improvement = best_result.accuracy - starting_accuracy

    # Summary
    await flyte.report.log.aio(
        f"""
    <div class="summary-card">
        <h3>🏆 Summary</h3>
        <p><strong>Best Prompt:</strong> {html.escape(best_result.prompt)}</p>
        <p><strong>Best Accuracy:</strong> {best_result.accuracy*100:.2f}%</p>
        <p><strong>Improvement Over Baseline:</strong> {improvement*100:.2f}%</p>
    </div>
    """,
        do_flush=True,
    )

    return best_result.prompt, best_result.accuracy

# {{/docs-fragment prompt_optimizer}}

async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
        <tr>
            <td>{html.escape(prompt)}</td>
            <td>
                {pct:.1f}%
                <div class="accuracy-bar-container">
                    <div class="accuracy-bar" style="width:{pct*1.6}px; height:8px; border-radius:4px; background:{color};"></div>
                </div>
            </td>
        </tr>
        """,
        do_flush=True,
    )

# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a geometric shapes problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth answer.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.

Model Response:
{response}

Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
<EXPLANATION>
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
</EXPLANATION>

<PROMPTS>
{prompt_scores_str}
</PROMPTS>

Each prompt was used together with a problem statement around geometric shapes.

<EXAMPLE>
<QUESTION>
This SVG path element <path d="M 55.57,80.69 L 57.38,65.80 M 57.38,65.80 L 48.90,57.46 M 48.90,57.46 L 45.58,47.78 M 45.58,47.78 L 53.25,36.07 L 66.29,48.90 L 78.69,61.09 L 55.57,80.69"/> draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
</QUESTION>
<ANSWER>
(B)
</ANSWER>
</EXAMPLE>

<TASK>
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
</TASK>

<RULES>
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc., for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
</RULES>
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)

    df_train, df_test = await data_prep(csv_file)

    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )

    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }

# {{/docs-fragment auto_prompt_engineering}}

# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py*
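The optimizer loop hinges on two plain-text contracts: the optimizer model wraps each proposal in double square brackets, and the review model replies `True` or `False`. The following Flyte-free sketch isolates those parsing steps so they can be checked locally. It reuses the regex from `prompt_optimizer` and the normalization rule from `generate_and_review`; the helper names `extract_prompt`, `normalize_verdict`, and `format_history` are hypothetical and are not part of the tutorial's code.

```python
import re
from typing import Optional

# Same pattern as prompt_optimizer: the first [[...]] span, matched non-greedily,
# with re.DOTALL so a proposed prompt may span several lines.
PROMPT_PATTERN = re.compile(r"\[\[(.*?)\]\]", re.DOTALL)


def extract_prompt(response: str) -> Optional[str]:
    """Return the text inside the first [[...]] pair, or None if the reply has none."""
    match = PROMPT_PATTERN.search(response.strip())
    return match.group(1) if match else None


def normalize_verdict(verdict: str) -> bool:
    """Map the review model's reply to a boolean; anything other than 'true' counts as incorrect."""
    return verdict.strip().lower() == "true"


def format_history(results: list[tuple[str, float]]) -> str:
    """Render (prompt, accuracy) pairs in ascending order of accuracy, one per line."""
    ordered = sorted(results, key=lambda item: item[1])
    return "\n".join(f"{prompt}: {accuracy:.2f}" for prompt, accuracy in ordered)


if __name__ == "__main__":
    reply = "Short prompts did worse.\n[[Trace each path segment, count the vertices, then choose one option.]]"
    print(extract_prompt(reply))
    print(normalize_verdict(" True\n"))
    print(format_history([("Prompt B", 0.71), ("Prompt A", 0.64)]))
```

Because the pattern is non-greedy, only the first bracketed span is used if the optimizer happens to emit several, and any "not sure" style reply from the reviewer is counted as incorrect.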

## Run it

We add a simple main block so we can run the workflow as a script:

```
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
```

Run it with:

```
uv run optimizer.py
```

![Execution](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/gifs/tutorials/prompt_engineering/execution.gif)

## Why this matters

Most prompt engineering pipelines start as quick scripts or notebooks. They're fine for experimenting, but they're difficult to scale, reproduce, or debug when things go wrong.

With Flyte 2, we get a more reliable setup:

- Run many evaluations in parallel with [async Python](https://www.union.ai/docs/v2/union/user-guide/flyte-2/async/page.md) or [native DSL](https://www.union.ai/docs/v2/union/user-guide/flyte-2/async/page.md).
- Watch accuracy improve in real time and link results back to the exact dataset, prompt, and model config used.
- Resume cleanly after failures without rerunning everything from scratch.
- Reuse the same pattern to tune other parameters like temperature, retrieval depth, or agent strategies, not just prompts.

## Next steps

You now have a working automated prompt engineering pipeline. Here’s how you can take it further:

- **Optimize beyond prompts**: Tune temperature, retrieval strategies, or tool usage just like prompts.
- **Expand evaluation metrics**: Add latency, cost, robustness, or diversity alongside accuracy.
- **Move toward agentic evaluation**: Instead of single prompts, test how agents plan, use tools, and recover from failures in long-horizon tasks.

With this foundation, prompt engineering becomes repeatable, observable, and scalable, ready for production-grade LLM and agent systems.

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/micro-batching ===


# Batching strategies for efficient scaling

> [!NOTE]
> [View source on GitHub](https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/batching_patterns/batch_processing.ipynb) | [Run in Google Colab](https://colab.research.google.com/github/unionai/unionai-examples/blob/main/v2/tutorials/batching_patterns/batch_processing.ipynb)

This notebook demonstrates a production-ready pattern for processing millions of items efficiently using Flyte v2's advanced features. You'll learn how to build resilient, scalable workflows that can handle failures gracefully and optimize resource consumption.

## Use Case

**The Challenge:** Processing massive datasets (100K to 1M+ items) that require external API calls or long-running operations.

**Real-World Examples:**
- Web scraping large lists of URLs
- Batch inference on millions of data points
- Processing documents through external APIs
- ETL pipelines with rate-limited services
- Data validation against third-party services

**The Problem:** When you have so many inputs that you must:
1. Split them into batches 
2. Submit each batch to an external service and wait for completion
3. Handle failures without losing progress
4. Optimize resource usage across thousands of operations

**Why This Matters:** Without proper batching and checkpointing, a single failure in a million-item workflow could force you to restart from scratch, wasting compute resources and time.

## Goals

**Our Goals:**
1. **Resilience:** Mitigate the impact of batches that take longer than expected or fail
2. **Determinism:** Make operations with external API dependencies predictable and resumable
3. **Efficiency:** Optimize resource consumption through container reuse and parallel processing
4. **Cost Savings:** Minimize wasted compute by checkpointing progress

## Solution Architecture

This example demonstrates a production-ready micro-batching pattern that combines two Union features:

### 1. Failure Transparency with `@flyte.trace`
The `@flyte.trace` decorator creates automatic checkpoints:
- **What it does:** Records inputs and outputs of decorated functions
- **Why it matters:** If a task fails and is retried, traced calls that already completed return their recorded outputs instead of running again
- **Result:** No re-execution of completed work
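The record-and-replay idea can be illustrated with a toy, in-process sketch. The `checkpointed` decorator and the `CHECKPOINTS` dict below are hypothetical and are not how Flyte implements `@flyte.trace`; they only show why a retry that happens after the submit step does not submit again.

```python
import asyncio
import functools
from typing import Any, Awaitable, Callable, Dict, Tuple

# Hypothetical stand-in for durable checkpoint storage.
# Flyte persists trace records remotely; this dict only lives in one process.
CHECKPOINTS: Dict[Tuple[str, Tuple[Any, ...]], Any] = {}
CALLS = {"submit": 0}


def checkpointed(fn: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    """Toy record/replay decorator (NOT Flyte's implementation)."""

    @functools.wraps(fn)
    async def wrapper(*args: Any) -> Any:
        key = (fn.__name__, args)
        if key in CHECKPOINTS:        # replay: return the recorded output
            return CHECKPOINTS[key]
        result = await fn(*args)      # execute the real work
        CHECKPOINTS[key] = result     # record only after success
        return result

    return wrapper


@checkpointed
async def submit(items: Tuple[int, ...]) -> Dict[int, str]:
    CALLS["submit"] += 1              # counts real executions, not replays
    await asyncio.sleep(0)            # simulated network call
    return {i: f"job_{i}" for i in items}


async def attempt(items: Tuple[int, ...], crash_after_submit: bool) -> Dict[int, str]:
    jobs = await submit(items)
    if crash_after_submit:
        raise RuntimeError("simulated crash after the submit phase")
    return jobs


async def main() -> Dict[int, str]:
    items = (1, 2, 3)
    try:
        await attempt(items, crash_after_submit=True)      # first attempt crashes
    except RuntimeError:
        pass
    return await attempt(items, crash_after_submit=False)  # "retry" replays submit


result = asyncio.run(main())
print(result, CALLS)  # {1: 'job_1', 2: 'job_2', 3: 'job_3'} {'submit': 1}
```

In Flyte, trace records survive pod restarts, which this single-process dict cannot do; the sketch only models the replay decision.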

### 2. Reusable Containers for Efficiency
Instead of creating a new container for each task:
- **Container pools:** Pre-warmed replicas ready to handle work
- **Concurrent processing:** Each replica handles multiple items simultaneously
- **Automatic scaling:** Replicas scale between min/max based on workload
- **Resource optimization:** Dramatically reduced startup overhead

### Key Benefits:
- **Automatic checkpointing** at batch and operation boundaries  
- **Resume from last successful point** on any failure  
- **No wasted compute** - never re-execute completed work  
- **Massive parallelism** - process thousands of batches concurrently  
- **Cost efficient** - container reuse minimizes cold-start overhead  

### Architecture Flow:
```
1M items → Split into 1,000 batches (1K each)
         ↓
    Parallel processing across reusable container pool
         ↓
    Each batch: Submit → Poll → Checkpoint
         ↓
    Aggregate results from all batches
```

### Architecture Diagram

![Micro-batching Architecture](./images/micro-batching.png)

**Diagram shows:**
- Input data split into batches
- Reusable container pool 
- Concurrent processing within each replica 
- Submit and wait phases with `@flyte.trace` checkpoints
- Parallel execution across all batches

## Implementation

### Step 0: Set up the runtime
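Install the Flyte SDK and the reusable-containers package. The leading `!` runs the command from a notebook cell; drop it when running in a terminal.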

```python
!uv pip install --no-cache --prerelease=allow --upgrade "flyte>=2.0.0b52" "unionai-reuse>=0.1.10"
```

### Step 1: Initialize Flyte Configuration

Configure your connection to the Flyte cluster. This tells Flyte where to run your workflows and how to build container images.

**Configuration Options:**
- `endpoint`: Your Flyte cluster URL
- `org`: Your organization name
- `project`: Project to organize workflows
- `domain`: Environment (development, staging, production)
- `image_builder`: Use "remote" to build images on the cluster (no local Docker required)

```python
# Initialize connection to your Flyte cluster
# Replace these values with your own cluster details

import flyte
flyte.init(
    endpoint="https://<MY_TENANT_HOST>",  # Your Union cluster URL
    org="demo",                           # Your organization
    project="flytesnacks",                # Your project name
    domain="development",                 # Environment: development/staging/production
    image_builder="remote",               # Build images on cluster (no local Docker needed)
    auth_type="DeviceFlow",               # Authenticate via the device-code flow
)
```

```python
# Import required libraries
import asyncio                          # For concurrent async operations
from datetime import timedelta          # For time-based configuration
from pathlib import Path                # For file path handling
from typing import Dict, List           # For type hints

import flyte                            # Main Flyte SDK
from flyte.remote import Run            # For interacting with remote executions
```

```python
# ============================================
# CONFIGURATION: Adjust these for your use case
# ============================================

# Total number of items to process
# In production, this could be the size of your dataset
NUMBER_OF_INPUTS = 1_000_000  # 1 million items

# Size of each batch
# Considerations for choosing batch size:
# - Larger batches: Fewer tasks, more memory per task
# - Smaller batches: More granular checkpointing, better parallelism
# - Recommendation: Start with 1000-10000 depending on item complexity
BATCH_SIZE = 1000

# Example calculations:
# 1M items ÷ 1K batch = 1,000 parallel batch tasks
# Each batch processes 1K items concurrently within its container
```
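The batch arithmetic above can be checked with the same list comprehension the orchestrator uses in Step 6, wrapped here in a hypothetical `batch_ranges` helper. Ranges are half-open, matching `process_batch`'s inclusive start and exclusive end.

```python
from typing import List, Tuple


def batch_ranges(total_items: int, batch_size: int) -> List[Tuple[int, int]]:
    """Split [0, total_items) into half-open (start, end) ranges of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return [
        (start, min(start + batch_size, total_items))
        for start in range(0, total_items, batch_size)
    ]


NUMBER_OF_INPUTS = 1_000_000
BATCH_SIZE = 1000

ranges = batch_ranges(NUMBER_OF_INPUTS, BATCH_SIZE)
print(len(ranges), ranges[0], ranges[-1])  # 1000 (0, 1000) (999000, 1000000)

# A total that does not divide evenly produces a shorter final batch.
print(batch_ranges(2_500, 1_000))  # [(0, 1000), (1000, 2000), (2000, 2500)]
```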

### Step 2: Define Container Image

Create a container image specification with all required dependencies.

**Key Dependencies:**
- `flyte>=2.0.0b52`: Flyte v2 SDK for workflow orchestration
- `unionai-reuse>=0.1.10`: Required for Reusable Containers feature

**Note:** You can add any additional packages your tasks need (e.g., `httpx` for API calls, `beautifulsoup4` for web scraping, etc.)

```python
# Define the container image that will run our tasks
# This image will be built once and shared across all task executions
image = (
    flyte.Image.from_debian_base()  # Start with a lightweight Debian base
    .with_pip_packages(
        "flyte>=2.0.0b52",          # Flyte v2 SDK
        "unionai-reuse>=0.1.10",     # Required for reusable containers
        # Add your own dependencies here
    )
)
```

### Step 3: Define Task Environments

Task environments encapsulate the runtime configuration for tasks. We'll create one with **Reusable Containers** for efficient batch processing.

#### What are Reusable Containers?

Instead of creating a new Kubernetes Pod for every task execution, Reusable Containers maintain a pool of pre-warmed replicas that can handle multiple tasks sequentially or concurrently.

**Benefits:**
- **Faster execution:** No container startup overhead (can save 10-60 seconds per task)
- **Better resource utilization:** Containers stay warm and handle multiple items
- **Cost savings:** Especially significant for tasks with expensive initialization
- **Concurrent processing:** Each replica can process multiple items simultaneously

```python
# Create a TaskEnvironment with Reusable Containers for batch processing
batch_env = flyte.TaskEnvironment(
    name="batch_processor",  # Name used for Kubernetes pods: batch_processor-<hash>

    # Resource allocation per replica (per pod)
    resources=flyte.Resources(
        memory="2Gi",  # Memory per replica
        cpu="1"        # CPU cores per replica
    ),

    # Reusable container configuration
    reusable=flyte.ReusePolicy(
        # Number of replica pods to maintain
        # (min, max) - scales between these values based on workload
        replicas=(3, 10),  # Start with 3, scale up to 10 as needed

        # Concurrency: How many items each replica processes simultaneously
        # Higher = more throughput per replica, but more memory usage
        concurrency=5,  # Each pod handles 5 concurrent operations

        # How long idle replicas stay alive before being torn down
        idle_ttl=timedelta(minutes=5),  # Keep warm for 5 minutes
    ),

    # Use the container image we defined earlier
    image=image,
)

# CAPACITY CALCULATION:
# With replicas=(3, 10) and concurrency=5:
# - Minimum concurrent processing: 3 replicas × 5 concurrency = 15 operations
# - Maximum concurrent processing: 10 replicas × 5 concurrency = 50 operations
#
# For 1,000 batches with these settings:
# - Best case: 50 batches processing simultaneously (at max scale, if nothing
#   upstream limits how many batches are submitted at once)
# - Time to process all: ~20 rounds of execution (1,000 / 50)
```

#### Understanding TaskEnvironment Parameters

**name:** 
- Used as the prefix for Kubernetes pod names
- Example: `batch_processor-abc123`

**resources:** 
- Compute resources allocated to *each replica*
- Set based on your task's memory and CPU needs
- Tip: Monitor actual usage and adjust accordingly

**replicas (min, max):**
- Flyte autoscales between these values based on workload
- More replicas = more parallel processing capacity
- Consider your cluster's capacity and quota limits

**concurrency:**
- Number of async operations each Python process (per pod) handles simultaneously
- This is *within* each replica, not across replicas
- Higher values increase throughput but require more memory
- Best for I/O-bound tasks (API calls, web scraping)
- For CPU-bound tasks, keep this lower (1-2)

**idle_ttl:**
- Time replicas stay alive without active work before shutdown
- Longer TTL = faster subsequent executions, higher resource costs
- Shorter TTL = lower costs, potential startup delays
- Recommendation: 5-15 minutes for typical workloads

**image:**
- The container image specification with all dependencies
- Built once and reused across all task executions
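The capacity comment in the `TaskEnvironment` code is simple arithmetic, and a small helper makes it easy to re-run for other settings. `capacity_and_waves` is a hypothetical name, the optional `max_in_flight` models any extra cap (such as a semaphore in the calling task), and "waves" assumes every batch takes the same time, which real workloads will not.

```python
import math
from typing import Optional, Tuple


def capacity_and_waves(
    num_batches: int,
    replicas: Tuple[int, int],
    concurrency: int,
    max_in_flight: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Return (slots at min scale, slots at max scale, waves at max scale).

    A slot is one concurrent task on one replica: slots = replicas x concurrency.
    """
    min_replicas, max_replicas = replicas
    min_slots = min_replicas * concurrency
    max_slots = max_replicas * concurrency
    # Any upstream cap lowers the number of batches that can actually run at once.
    effective = max_slots if max_in_flight is None else min(max_slots, max_in_flight)
    waves = math.ceil(num_batches / effective)
    return min_slots, max_slots, waves


print(capacity_and_waves(1_000, (3, 10), 5))                    # (15, 50, 20)
print(capacity_and_waves(1_000, (3, 10), 5, max_in_flight=10))  # (15, 50, 100)
```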

#### Creating the Orchestrator Environment

The orchestrator task coordinates all batch processing but doesn't need container reuse since it only runs once per workflow execution.

```python
# Create a separate environment for the orchestrator task
orchestrator_env = flyte.TaskEnvironment(
    name="orchestrator",

    # depends_on: the orchestrator task calls process_batch, which runs in batch_env,
    # so batch_env is declared as a dependency and deployed together with this environment.
    # Both environments use the same `image` object below, so the image is not rebuilt.
    depends_on=[batch_env],

    # Orchestrator needs more memory to track all batch executions
    # but doesn't need reusable containers (runs once per workflow)
    resources=flyte.Resources(
        memory="4Gi",  # More memory to manage many parallel batches
        cpu="1"        # Single CPU is sufficient for orchestration
    ),

    image=image,  # Same image, different resource allocation
)
```

#### Why Two Environments?

**Separation of Concerns:**
- **Batch Environment:** Does the heavy lifting (processing items)
  - Needs reusable containers for efficiency
  - Scales horizontally (many replicas)
  - I/O bound operations benefit from concurrency

- **Orchestrator Environment:** Coordinates the workflow
  - Runs once per workflow execution
  - Doesn't need container reuse
  - Needs enough memory to track all batches
  - Mostly waits on batch tasks, so a single CPU is enough

This separation optimizes both cost and performance.

### Step 4: Define External Service Interactions

These helper functions simulate interactions with external services (APIs, web scraping, etc.). 

```python
async def submit_to_service(request_id: int) -> str:
    """
    Submit a request to an external service and get a job ID.

    This simulates the "submit" phase of a batch job pattern where you:
    1. Send data to an external service
    2. Receive a job/task ID for tracking
    3. Use that ID to poll for completion later

    PRODUCTION IMPLEMENTATION:
    Replace this simulation with your actual service call:

    ```python
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://your-service.com/api/submit",
            json={"request_id": request_id, "data": your_data},
            timeout=30.0
        )
        response.raise_for_status()
        return response.json()["job_id"]
    ```

    Args:
        request_id: Unique identifier for this request

    Returns:
        job_id: Identifier to track this job's progress
    """
    await asyncio.sleep(0.01)  # Simulate network latency
    job_id = f"job_{request_id}"
    return job_id

async def poll_job_status(job_id: str, request_id: int) -> int:
    """
    Poll an external service until a job completes and return results.

    This simulates the "wait" phase where you:
    1. Repeatedly check if a submitted job has completed
    2. Wait between checks to avoid overwhelming the service
    3. Return the final result when ready

    PRODUCTION IMPLEMENTATION:
    Replace this simulation with your actual polling logic:

    ```python
    async with httpx.AsyncClient() as client:
        max_attempts = 60  # 5 minutes with 5-second intervals

        for attempt in range(max_attempts):
            response = await client.get(
                f"https://your-service.com/api/status/{job_id}",
                timeout=10.0
            )
            response.raise_for_status()
            status = response.json()

            if status["state"] == "completed":
                return status["result"]
            elif status["state"] == "failed":
                raise Exception(f"Job {job_id} failed: {status['error']}")

            # Wait before next poll
            await asyncio.sleep(5)

        raise TimeoutError(f"Job {job_id} did not complete in time")
    ```

    Args:
        job_id: The job identifier from submit_to_service
        request_id: Original request ID for logging/tracking

    Returns:
        result: The processed result from the external service
    """
    await asyncio.sleep(0.05)  # Simulate polling + processing time
    return request_id * 2  # Dummy result

# IMPORTANT NOTES:
# 1. Both functions are async - they don't block while waiting
# 2. Add logging for debugging and monitoring
```
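The production polling loop in the docstring follows a bounded check, wait, repeat pattern. The sketch below exercises that pattern without a network by substituting a hypothetical in-memory `FakeService` whose jobs complete after a fixed number of status checks; the class, its methods, and the `state`/`result` fields are assumptions that mirror the docstring, not a real API.

```python
import asyncio
from typing import Any, Dict


class FakeService:
    """Hypothetical in-memory service: a job completes after a fixed number of status checks."""

    def __init__(self, polls_until_done: int) -> None:
        self.polls_until_done = polls_until_done
        self.polls: Dict[str, int] = {}

    async def submit(self, request_id: int) -> str:
        job_id = f"job_{request_id}"
        self.polls[job_id] = 0
        return job_id

    async def status(self, job_id: str, request_id: int) -> Dict[str, Any]:
        self.polls[job_id] += 1
        if self.polls[job_id] >= self.polls_until_done:
            return {"state": "completed", "result": request_id * 2}
        return {"state": "running"}


async def poll_until_done(
    service: FakeService,
    job_id: str,
    request_id: int,
    max_attempts: int = 5,
    interval: float = 0.01,
) -> int:
    """Bounded polling: check, wait, repeat, then give up with TimeoutError."""
    for _ in range(max_attempts):
        status = await service.status(job_id, request_id)
        if status["state"] == "completed":
            return status["result"]
        await asyncio.sleep(interval)  # wait between checks to avoid hammering the service
    raise TimeoutError(f"Job {job_id} did not complete in time")


async def demo(polls_until_done: int, max_attempts: int) -> int:
    service = FakeService(polls_until_done)
    job_id = await service.submit(21)
    return await poll_until_done(service, job_id, 21, max_attempts=max_attempts)


print(asyncio.run(demo(polls_until_done=3, max_attempts=5)))  # 42
```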

### Step 5: Implement the Batch Processing Task

This is the heart of the pattern. The `process_batch` task processes a batch of items with automatic checkpointing using `@flyte.trace`.

#### Key Concepts:

**Two-Phase Processing:**
1. **Submit Phase:** Send all items to external service concurrently
2. **Wait Phase:** Poll for completion of all submitted jobs

**Why @flyte.trace?**
- Creates checkpoints at phase boundaries
- If the task fails during the wait phase, a retry resumes from there without re-submitting
- Enables forward recovery without re-executing completed phases

**Concurrency Pattern:**
- Uses `asyncio.gather()` to process all items in a batch simultaneously
- `return_exceptions=True` prevents one failure from stopping the batch
- Each phase completes fully before moving to the next
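The `return_exceptions=True` behavior described above is standard `asyncio.gather()` semantics and can be checked in isolation. In this sketch, `flaky_operation` is hypothetical; failures are mapped to `-1`, the same marker `process_batch` uses for failed items.

```python
import asyncio
from typing import List


async def flaky_operation(x: int) -> int:
    """Hypothetical operation that fails for every multiple of 4."""
    await asyncio.sleep(0)
    if x % 4 == 0:
        raise ValueError(f"item {x} failed")
    return x * 2


async def run_batch(items: List[int]) -> List[int]:
    # With return_exceptions=True, gather() returns exceptions as values, in the
    # same order as the inputs, instead of raising the first one.
    outcomes = await asyncio.gather(
        *(flaky_operation(x) for x in items), return_exceptions=True
    )
    return [-1 if isinstance(o, Exception) else o for o in outcomes]


async def run_batch_strict(items: List[int]) -> List[int]:
    # Default behavior: the first exception propagates out of gather().
    return list(await asyncio.gather(*(flaky_operation(x) for x in items)))


print(asyncio.run(run_batch([1, 2, 3, 4, 5])))  # [2, 4, 6, -1, 10]
```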

```python
@batch_env.task  # This task runs in the reusable container pool
async def process_batch(batch_start: int, batch_end: int) -> List[int]:
    """
    Process a single batch of items with checkpointed phases.

    This function demonstrates the core micro-batching pattern with:
    1. Two-phase processing (submit → wait)
    2. Automatic checkpointing via @flyte.trace
    3. Error handling without stopping the entire batch
    4. Concurrent processing within the batch

    Args:
        batch_start: Starting index for this batch (inclusive)
        batch_end: Ending index for this batch (exclusive)

    Returns:
        List of processed results (or -1 for failed items)

    Example:
        process_batch(0, 1000) processes items 0-999
        process_batch(1000, 2000) processes items 1000-1999
    """

    # ========================================
    # PHASE 1: SUBMIT ALL ITEMS TO SERVICE
    # ========================================
    @flyte.trace  # Creates a checkpoint after this phase completes
    async def submit_phase(items: List[int]) -> Dict[int, str]:
        """
        Submit all items concurrently and collect job IDs.

        This function:
        1. Launches submit_to_service() for ALL items simultaneously
        2. Waits for all submissions to complete with asyncio.gather()
        3. Handles errors gracefully (return_exceptions=True)
        4. Maps each request_id to its job_id (or None if failed)

        Why @flyte.trace here:
        - If this phase succeeds but wait_phase fails, we don't re-submit
        - Checkpointed data includes all job_ids for the wait phase
        - Forward recovery from exact failure point

        """

        job_ids = await asyncio.gather(
            *(submit_to_service(request_id=x) for x in items),
            return_exceptions=True  # Don't stop on individual failures
        )

        # Map request IDs to job IDs (or None for failures)
        job_mapping = {}
        for request_id, job_id in zip(items, job_ids):
            if isinstance(job_id, Exception):
                print(f"[ERROR] Submit failed for {request_id}: {job_id}")
                job_mapping[request_id] = None  # Mark as failed
            else:
                job_mapping[request_id] = job_id

        return job_mapping

    # ========================================
    # PHASE 2: WAIT FOR ALL JOBS TO COMPLETE
    # ========================================
    @flyte.trace  # Creates another checkpoint after this phase completes
    async def wait_phase(job_mapping: Dict[int, str]) -> List[int]:
        """
        Poll all submitted jobs until completion.

        This function:
        1. Takes the checkpointed job_mapping from submit_phase
        2. Polls all jobs concurrently
        3. Handles polling errors gracefully
        4. Returns final results

        WHY @flyte.trace HERE:
        - If polling fails partway through, a retry reuses the cached job_mapping
        - Jobs that were already submitted are not re-submitted
        - The phase is checkpointed as a whole once it completes; individual
          polls are not, so a retry re-runs every poll in this phase

        ERROR HANDLING:
        - Jobs that failed in submit_phase (None) are skipped
        - Polling failures are caught and marked as -1
        - The batch continues even if some items fail
        """
        # Poll ALL jobs concurrently
        results = await asyncio.gather(
            *(
                poll_job_status(job_id=job_id, request_id=request_id)
                if job_id is not None  # Only poll successfully submitted jobs
                else asyncio.sleep(0)   # Skip failed submissions
                for request_id, job_id in job_mapping.items()
            ),
            return_exceptions=True  # Don't stop on individual failures
        )

        # Process results and handle errors
        processed_results = []
        for (request_id, job_id), result in zip(job_mapping.items(), results):
            if job_id is None:
                # Submission already failed; asyncio.sleep(0) returned None as a placeholder
                processed_results.append(-1)
            elif isinstance(result, Exception):
                print(f"[ERROR] Wait failed for {request_id}: {result}")
                processed_results.append(-1)  # Mark as failed
            else:
                processed_results.append(result)

        return processed_results

    # ========================================
    # EXECUTE BOTH PHASES SEQUENTIALLY
    # ========================================
    # Create the list of items for this batch
    items = list(range(batch_start, batch_end))

    # Phase 1: Submit all items and get job IDs (checkpointed)
    job_mapping = await submit_phase(items)

    # Phase 2: Wait for all jobs to complete (checkpointed)
    results = await wait_phase(job_mapping)

    # Log batch completion stats
    successful = len([r for r in results if r != -1])
    print(f"Batch {batch_start}-{batch_end}: {successful}/{len(results)} successful")

    return results

# ========================================
# CHECKPOINT & RECOVERY BEHAVIOR
# ========================================
#
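# (These scenarios assume the task is configured to retry on failure.)
#
# Scenario 1: Task fails during submit_phase
# → No checkpoint exists yet, so the retry runs submit_phase from the start
#   (items may be submitted again unless the service deduplicates them)
#
# Scenario 2: Task fails after submit_phase completes
# → Resumes directly to wait_phase with cached job_mapping
# → No re-submissions!
#
# Scenario 3: Task fails during wait_phase
# → Resumes wait_phase with cached job_mapping
# → wait_phase runs again in full: jobs already polled are polled again,
#   because a phase is only checkpointed once it completes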

```

#### Understanding @flyte.trace

**Why use it for both phases:**
- Submit phase checkpoint = "These jobs were submitted successfully"
- Wait phase checkpoint = "These results were retrieved successfully"
- Without it: a retry would re-run the whole task, re-submitting and re-polling every item; with it, only the interrupted phase runs again

**Best Practices:**
- Use `@flyte.trace` for non-deterministic operations (API calls, random operations)
- Don't use it for pure, deterministic functions (unnecessary overhead)
- Ensure traced functions are idempotent when possible
- Keep traced function signatures simple (serializable inputs/outputs)

See the [Traces](/docs/v2/byoc//user-guide/task-programming/traces/) docs for more details on how traces work.
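One common way to make a submit step idempotent is an idempotency key: the client derives a deterministic key from the input, and the service returns the existing job for any key it has already seen. This only helps if your external service supports such keys. `FakeIdempotentService` and its `idempotency_key` parameter are hypothetical, used here to show why re-running `submit_phase` (Scenario 1 above) would then not create duplicate jobs.

```python
import asyncio
import itertools
from typing import Dict, List, Tuple


class FakeIdempotentService:
    """Hypothetical service that deduplicates submissions by idempotency key."""

    def __init__(self) -> None:
        self._next_id = itertools.count(1)
        self.jobs_by_key: Dict[str, str] = {}
        self.jobs_created = 0

    async def submit(self, payload: int, idempotency_key: str) -> str:
        await asyncio.sleep(0)  # stand-in for network latency
        if idempotency_key not in self.jobs_by_key:
            self.jobs_created += 1
            self.jobs_by_key[idempotency_key] = f"job_{next(self._next_id)}"
        return self.jobs_by_key[idempotency_key]


async def submit_item(service: FakeIdempotentService, request_id: int) -> str:
    # Derive the key from the input alone, so a re-run sends the same key.
    return await service.submit(request_id, idempotency_key=f"request-{request_id}")


async def demo() -> Tuple[List[str], List[str], int]:
    service = FakeIdempotentService()
    first = await asyncio.gather(*(submit_item(service, i) for i in range(3)))
    # Simulate the submit step running again after a crash.
    second = await asyncio.gather(*(submit_item(service, i) for i in range(3)))
    return list(first), list(second), service.jobs_created


first, second, created = asyncio.run(demo())
print(first == second, created)  # True 3
```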

### Step 6: Implement the Orchestrator Workflow

The orchestrator is the top-level task that:
1. Splits the total workload into batches
2. Launches all batches in parallel
3. Aggregates results from all batches
4. Reports overall statistics

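**This is where the parallelism comes from:** All batches are launched at once; how many actually run concurrently is bounded by the orchestrator's semaphore (`max_concurrent_batches`) and by your reusable container pool configuration.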

```python
@orchestrator_env.task  # Runs in the orchestrator environment (no reuse)
async def microbatch_workflow(
    total_items: int = NUMBER_OF_INPUTS,
    batch_size: int = BATCH_SIZE,
) -> List[int]:
    """
    Main task orchestrating the entire micro-batching process.

    This task:
    1. Splits the items into contiguous batch ranges
    2. Launches all batch tasks in parallel
    3. Aggregates results from completed batches
    4. Provides comprehensive execution statistics

    Args:
        total_items: Total number of items to process (default: 1M)
        batch_size: Number of items per batch (default: 1K)

    Returns:
        Aggregated results from all batches (list of processed values)

    Execution Flow:
        1M items → 1,000 batches → Parallel execution → Aggregated results

    Resource Usage:
        - This task: 4Gi memory, 1 CPU (orchestration only)
        - Each batch task: 2Gi memory, 1 CPU (from batch_env)
        - Reusable containers handle actual processing
    """

    # ========================================
    # STEP 1: CALCULATE BATCH DISTRIBUTION
    # ========================================
    # Split total items into batch ranges: [(0, 1000), (1000, 2000), ...]
    batches = [
        (start, min(start + batch_size, total_items))
        for start in range(0, total_items, batch_size)
    ]

    print(f"Processing {total_items:,} items in {len(batches):,} batches of size {batch_size:,}")
    print(f"Expected parallelism: {batch_env.reusable.replicas[0]}-{batch_env.reusable.replicas[1]} replicas")
    print(f"Concurrency per replica: {batch_env.reusable.concurrency}")
    print(f"Max simultaneous batches: {batch_env.reusable.replicas[1] * batch_env.reusable.concurrency}")

    # ========================================
    # STEP 2: LAUNCH ALL BATCHES IN PARALLEL
    # ========================================
    # This is the key to massive parallelism:
    # - Creates one coroutine per batch (1,000 with the defaults)
    # - How many actually run at once is bounded by the semaphore below
    #   and by the reusable container pool (replicas x concurrency)
    # - Reusable containers handle the workload efficiently
    # - return_exceptions=True prevents one batch failure from stopping all

    print(f"\n Launching {len(batches):,} parallel batch tasks...")

    # Rate limiter: at most max_concurrent_batches process_batch calls are in flight
    max_concurrent_batches = 10  # Adjust based on API rate limits
    semaphore = asyncio.Semaphore(max_concurrent_batches)

    async def rate_limited_batch(start: int, end: int):
        """Wrapper to enforce rate limiting on batch processing."""
        async with semaphore:
            return await process_batch(batch_start=start, batch_end=end)

    batch_results = await asyncio.gather(
        *(rate_limited_batch(start, end) for start, end in batches),
        return_exceptions=True  # Isolated failure handling per batch
    )
    # ========================================
    # STEP 3: AGGREGATE RESULTS & STATISTICS
    # ========================================
    all_results = []
    failed_batches = 0
    failed_items = 0

    for i, batch_result in enumerate(batch_results):
        if isinstance(batch_result, Exception):
            # Entire batch failed (task-level failure)
            print(f"[ERROR] Batch {i} failed completely: {batch_result}")
            failed_batches += 1
        else:
            # Batch completed, but individual items may have failed
            all_results.extend(batch_result)
            failed_items += len([r for r in batch_result if r == -1])

    # Calculate final statistics
    success_count = len([r for r in all_results if r != -1])
    total_processed = len(all_results)

    # ========================================
    # STEP 4: REPORT EXECUTION SUMMARY
    # ========================================
    print(f"\n{'=' * 60}")
    print(f" Execution summary")
    print(f"{'=' * 60}")
    print(f"Total items requested:    {total_items:,}")
    print(f"Total batches:            {len(batches):,}")
    print(f"Batch size:               {batch_size:,}")
    print(f"")
    print(f" Successful items:       {success_count:,}")
    print(f" Failed items:           {failed_items:,}")
    print(f" Failed batches:         {failed_batches}")
    print(f"")
    print(f" Success rate:           {success_count / total_items * 100:.2f}%")
    print(f" Items processed:        {total_processed:,} / {total_items:,}")
    print(f"{'=' * 60}\n")

    return all_results

# ========================================
# EXECUTION BEHAVIOR & OPTIMIZATION
# ========================================
#
# Parallel Execution Pattern:
# ┌─────────────────────────────────────────────────┐
# │ Orchestrator Task (1 pod, 4Gi, 1 CPU)           │
# │                                                 │
# │ Launches 1,000 process_batch() invocations      │
# └─────────────────┬───────────────────────────────┘
#                   │
#           ┌───────┴────────┐
#           ▼                ▼
#   ┌──────────────┐  ┌──────────────┐
#   │ Replica 1    │  │ Replica 2    │  ... up to 10 replicas
#   │ 2Gi, 1 CPU   │  │ 2Gi, 1 CPU   │
#   │              │  │              │
#   │ Concurrency: │  │ Concurrency: │
#   │ 5 batches    │  │ 5 batches    │
#   └──────────────┘  └──────────────┘
#
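# With 10 replicas × 5 concurrency = 50 batches processing simultaneously
# Time to complete 1,000 batches ≈ 1,000 / 50 = 20 waves
# Note: the semaphore above caps in-flight batches at max_concurrent_batches (10),
# so as written this is ≈ 1,000 / 10 = 100 waves; raise it to use the full pool.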
#
# Optimization Tips:
# 1. Increase replicas for more parallelism (if cluster allows)
# 2. Adjust concurrency based on task I/O vs CPU profile
# 3. Tune batch_size to balance granularity vs overhead
# 4. Monitor actual execution to find bottlenecks
# 5. Use Flyte UI to visualize execution patterns
```
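To see how the orchestrator's semaphore bounds concurrency independently of the container pool, here is a local sketch that launches one coroutine per batch behind an `asyncio.Semaphore` and records the peak number in flight. `fake_process_batch` is a hypothetical stand-in for `process_batch`; no Flyte calls are involved.

```python
import asyncio
from typing import List, Tuple


async def bounded_gather(batches: List[Tuple[int, int]], limit: int) -> Tuple[List[int], int]:
    """Launch one coroutine per batch, let at most `limit` run at once,
    and report the peak number that were in flight."""
    semaphore = asyncio.Semaphore(limit)
    in_flight = 0
    peak = 0

    async def fake_process_batch(start: int, end: int) -> int:
        # Hypothetical stand-in for process_batch: returns the batch size.
        nonlocal in_flight, peak
        async with semaphore:
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0.001)  # simulated remote work
            in_flight -= 1
            return end - start

    sizes = await asyncio.gather(*(fake_process_batch(s, e) for s, e in batches))
    return list(sizes), peak


batches = [(start, min(start + 100, 2_050)) for start in range(0, 2_050, 100)]
sizes, peak = asyncio.run(bounded_gather(batches, limit=10))
print(len(batches), sum(sizes), peak)  # 21 2050 10
```

Every batch still completes; the semaphore only changes how many are active at the same moment.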

### Step 7: Execute the Workflow

Now let's run the entire workflow remotely on your Union cluster.

**Execution Options:**
- **Remote execution** (shown below): Runs on the Union cluster
- **Local execution**: Use `flyte.with_runcontext(mode="local").run()` for testing

**What happens during execution:**
1. Flyte builds the container image (if needed)
2. Creates the orchestrator pod
3. Orchestrator calculates batches and launches batch tasks
4. Reusable container pool starts spinning up (min: 3 replicas in this example)
5. Batches are distributed across available replicas
6. Pool scales up to max replicas (10 in this example) as needed
7. Results are aggregated and returned

```python
if __name__ == "__main__":
    print("=" * 60)
    print(" STARTING MICRO-BATCHING WORKFLOW")
    print("=" * 60)
    print(f"Total items to process: {NUMBER_OF_INPUTS:,}")
    print(f"Batch size: {BATCH_SIZE:,}")
    print(f"Expected batches: {NUMBER_OF_INPUTS // BATCH_SIZE:,}")
    print("=" * 60)
    print()

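    # Launch the workflow remotely (runs on Flyte cluster)
    # The 'await' is needed because flyte.run.aio() is async. Top-level 'await'
    # works in a notebook cell; in a plain .py script, use flyte.run(microbatch_workflow).
    r = await flyte.run.aio(microbatch_workflow)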

    # Print execution details
    print(f"\n{'=' * 60}")
    print(f" EXECUTION STARTED")
    print(f"{'=' * 60}")
    # print(f"Run name: {r.name}")  # Internal run identifier
    print(f"🔗 Execution URL: {r.url}")
    print(f"\n💡 Visit the URL above to:")
    print(f"   • View the execution graph and task timeline")
    print(f"   • Monitor progress in real-time")
    print(f"   • See trace checkpoints in action")
    print(f"   • Inspect logs for each batch")
    print(f"   • Analyze resource utilization")
    print(f"{'=' * 60}\n")

# ========================================
# MONITORING AND DEBUGGING TIPS
# ========================================
#
# 1. View Execution in UI:
#    - Click the execution URL printed above
#    - See visual graph of all batch tasks
#    - Monitor which batches are running/completed/failed
#
# 2. Check Logs:
#    - Click on individual batch tasks in the graph
#    - View stdout/stderr for debugging
#    - See checkpoint/recovery messages
#
# 3. Resource Utilization:
#    - Navigate to Resources tab in UI
#    - Monitor CPU/memory usage per task
#    - Identify bottlenecks or over-provisioning
#
# 4. Trace Visualization:
#    - Expand batch tasks to see trace checkpoints
#    - Verify submit_phase and wait_phase separately
#    - Understand recovery points on failures
#
# 5. Performance Analysis:
#    - Check task durations in timeline view
#    - Identify slow batches or stragglers
#    - Optimize batch_size or concurrency based on results
```

At the Kubernetes level, a run of this example looks like this:

![](./images/reusable-containers-k8s.png)

That is, 10 replicas (as defined in the `TaskEnvironment`) plus the driver pod that runs the parent task (`a0`). [Learn more about the parent task](/docs/v2/byoc//user-guide/considerations/#driver-pod-requirements).

## Batch Size Selection

**Finding the optimal batch size:**
- **Too small:** More overhead from task management, less efficient
- **Too large:** Longer recovery time on failures, higher memory usage

**Factors to consider:**
- Item processing time (faster items = larger batches, to amortize the per-batch overhead)
- Memory consumption per item (higher = smaller batches)
- Failure tolerance (critical = smaller batches for faster recovery)
- Total workload size (larger total = can use larger batches)

Read the [Optimization strategies](/docs/v2/byoc//user-guide/run-scaling/scale-your-workflows/#2-batch-workloads-to-reduce-overhead) page to understand the overheads associated with an execution and how to choose an appropriate batch size.
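To make the trade-off concrete, here is a back-of-the-envelope model. It is a sketch under simplifying assumptions (every item costs the same, each batch pays a fixed overhead, and a failed batch restarts from scratch); the parameter values are illustrative, not measured Union overheads.

```python
import math


def batch_tradeoff(
    total_items: int, batch_size: int, per_item_s: float, per_batch_overhead_s: float
) -> dict:
    """Back-of-the-envelope cost model for one candidate batch size."""
    num_batches = math.ceil(total_items / batch_size)
    work_per_batch_s = batch_size * per_item_s
    return {
        "num_batches": num_batches,
        # Fraction of each batch's wall time spent on fixed per-batch overhead
        "overhead_fraction": per_batch_overhead_s / (per_batch_overhead_s + work_per_batch_s),
        # Assuming a failed batch restarts from scratch, up to one batch of work is redone
        "worst_case_rework_s": work_per_batch_s,
    }


# Illustrative numbers only: 100k items, 50 ms per item, 2 s overhead per batch
for size in (10, 100, 1_000, 10_000):
    print(size, batch_tradeoff(100_000, size, per_item_s=0.05, per_batch_overhead_s=2.0))
```

Small batches spend most of their wall time on overhead; large batches keep overhead low but put more work at risk per failure. With `@flyte.trace` checkpoints, the actual rework after a failure can be smaller than this model assumes.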

## Summary

This notebook demonstrated a production-ready micro-batching pattern for Flyte v2 that combines:

1. **Reusable Containers** for efficiency
2. **@flyte.trace** for checkpointing and recovery
3. **Massive parallelism** via async/await
4. **Robust error handling** for resilience

**Key Takeaways:**
- Use `@flyte.trace` for non-deterministic operations
- Monitor resource usage and optimize incrementally
- Choose the right pattern for your specific use case

**Next Steps:**
- Adapt this pattern to your specific use case
- Replace mock functions with real API calls
- Test with your actual dataset
- Monitor and optimize based on production metrics

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/deep-research ===

# Deep research

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/deep_research_agent); based on work by [Together AI](https://github.com/togethercomputer/open_deep_research).

This example demonstrates how to build an agentic workflow for deep research—a multi-step reasoning system that mirrors how a human researcher explores, analyzes, and synthesizes information from the web.

Deep research refers to the iterative process of thoroughly investigating a topic: identifying relevant sources, evaluating their usefulness, refining the research direction, and ultimately producing a well-structured summary or report. It's a long-running task that requires the agent to reason over time, adapt its strategy, and chain multiple steps together, making it an ideal fit for an agentic architecture.

In this example, we use:

- [Tavily](https://www.tavily.com/) to search for and retrieve high-quality online resources.
- [LiteLLM](https://litellm.ai/) to route LLM calls that perform reasoning, evaluation, and synthesis.

The agent executes a multi-step trajectory:

- Parallel search across multiple queries.
- Evaluation of retrieved results.
- Adaptive iteration: If results are insufficient, it formulates new research queries and repeats the search-evaluate cycle.
- Synthesis: Once the evaluator finds no remaining gaps, or the iteration budget is used up, it filters the sources and produces a comprehensive research report.
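The control flow of this trajectory can be sketched in plain `asyncio`. `search` and `evaluate` below are hypothetical stubs standing in for the Tavily search and LLM evaluation steps; the real agent also summarizes, deduplicates, and filters results before synthesis.

```python
import asyncio


async def search(query: str) -> list[str]:
    # Hypothetical stand-in for a web search call
    await asyncio.sleep(0)
    return [f"result for {query!r}"]


async def evaluate(topic: str, queries: list[str], results: list[str]) -> list[str]:
    # Hypothetical evaluator: requests one follow-up query until 3 queries have run
    return [f"{topic} follow-up {len(queries)}"] if len(queries) < 3 else []


async def search_all(queries: list[str]) -> list[str]:
    # Run every query concurrently and flatten the per-query result lists
    batches = await asyncio.gather(*(search(q) for q in queries))
    return [r for batch in batches for r in batch]


async def research(topic: str, budget: int) -> tuple[list[str], list[str]]:
    queries = [topic]
    results = await search_all(queries)
    for _ in range(budget):  # at most `budget` search-evaluate rounds
        new_queries = [q for q in await evaluate(topic, queries, results) if q]
        if not new_queries:  # evaluator is satisfied: stop early
            break
        results += await search_all(new_queries)
        queries += new_queries
    return queries, results


queries, results = asyncio.run(research("agent orchestration", budget=5))
print(queries)
```

The loop ends on whichever comes first: an empty list from the evaluator or the exhausted budget, which mirrors how `budget` bounds the iterations in the full agent.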

What makes this workflow compelling is its dynamic, evolving nature. The agent isn't just following a fixed plan; it's making decisions in context, using multiple prompts and reasoning steps to steer the process.

Flyte is well suited to this kind of system because it provides:

- Structured composition of dynamic reasoning steps
- Built-in parallelism for faster search and evaluation
- Traceability and observability into each step and iteration
- Scalability for long-running or compute-intensive workloads

![Result](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/gifs/tutorials/deep-research/result.gif)

Throughout this guide, we'll show how to design this workflow using the Flyte SDK, and how to unlock the full potential of agentic development with tools you already know and trust.

## Setting up the environment

Let's begin by setting up the task environment. We define the following components:

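- Secrets for the Together and Tavily API keys, exposed to tasks as the `TOGETHER_API_KEY` and `TAVILY_API_KEY` environment variables
- A custom image with the required Python packages and apt dependencies (`pandoc`, `texlive-xetex`)
- An external YAML file (`prompts.yaml`) containing all LLM prompts, copied into the image at `/root`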

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}

# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}

async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")

    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"<Raw Content>{raw_content}</Raw Content>\n\n<Research Topic>{query}</Research Topic>",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result

# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""

    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)

    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue

            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        # Skip failed summaries after zipping, so each summary stays paired with its source
        if isinstance(summarized_content, BaseException):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}

@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    # Launch one search_and_summarize call per query and run them concurrently
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]
    # An empty query list yields no results instead of an error
    results_list = await asyncio.gather(*tasks) if tasks else []

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results

# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

    # Map the LLM's 1-based source numbers back to results, ignoring out-of-range numbers
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while "<think>" in answer and "</think>" in answer:
        start = answer.find("<think>")
        end = answer.find("</think>") + len("</think>")
        answer = answer[:start] + answer[end:]
    return answer

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
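`search_and_summarize` gathers summaries with `asyncio.gather(..., return_exceptions=True)`, so one failed LLM call does not fail the whole task. The sketch below isolates that pattern with a hypothetical `summarize` coroutine. Pairing each outcome with its input before discarding exceptions keeps summaries aligned with their sources; dropping the exceptions first and zipping afterwards would attach later summaries to the wrong results.

```python
import asyncio


async def summarize(text: str) -> str:
    # Hypothetical summarizer that fails on one input
    if "bad" in text:
        raise RuntimeError("LLM call failed")
    return text.upper()


async def summarize_all(docs: list[str]) -> list[tuple[str, str]]:
    outcomes = await asyncio.gather(*(summarize(d) for d in docs), return_exceptions=True)
    # Pair each outcome with its source document *before* dropping failures,
    # so a failure in the middle cannot shift later summaries onto the wrong document
    return [
        (doc, out) for doc, out in zip(docs, outcomes) if not isinstance(out, BaseException)
    ]


pairs = asyncio.run(summarize_all(["alpha", "bad doc", "gamma"]))
print(pairs)  # [('alpha', 'ALPHA'), ('gamma', 'GAMMA')]
```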

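The Python dependencies are declared at the top of the file as `uv` inline script metadata (the PEP 723 `# /// script` format); `flyte.Image.from_uv_script` uses this header to build the task image: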

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# ///
```

## Generate research queries

This task converts a user prompt into a list of focused queries. It makes two LLM calls: the first generates a high-level research plan, and the second parses that plan into atomic search queries, returned as JSON that follows the `ResearchPlan` schema.
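Both calls stream their output in chunks, which the task accumulates into a string before using it. The sketch below reproduces that plan-then-parse shape with a hypothetical `fake_llm_stream` generator in place of `asingle_shot_llm_call`; the real task constrains the second call with the `ResearchPlan` JSON schema, which this stub only approximates with a manual type check.

```python
import asyncio
import json
from collections.abc import AsyncIterator


async def fake_llm_stream(text: str, chunk_size: int = 8) -> AsyncIterator[str]:
    # Hypothetical stand-in for a streaming LLM call: yields the reply in chunks
    for i in range(0, len(text), chunk_size):
        await asyncio.sleep(0)
        yield text[i : i + chunk_size]


async def plan_then_parse(topic: str) -> list[str]:
    # Call 1: free-form plan, streamed and accumulated into one string
    plan = ""
    async for chunk in fake_llm_stream(f"1. Define {topic}\n2. Compare {topic} tools"):
        plan += chunk
    # Call 2: a "parser" model returns JSON; here its reply is faked from the plan
    reply = json.dumps({"queries": [line[3:] for line in plan.splitlines()]})
    response_json = ""
    async for chunk in fake_llm_stream(reply):
        response_json += chunk
    queries = json.loads(response_json)["queries"]
    if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
        raise ValueError("parser did not return a list of strings")
    return queries


queries = asyncio.run(plan_then_parse("agent orchestration"))
print(queries)
```

The caller then caps the list; `research_topic` always searches the topic itself plus at most `max_queries - 1` generated queries.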

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}

# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}

async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")

    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"<Raw Content>{raw_content}</Raw Content>\n\n<Research Topic>{query}</Research Topic>",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result

# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""

    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)

    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue

            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        # Skip failed summaries after zipping, so each summary stays paired with its source
        if isinstance(summarized_content, BaseException):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}

@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    # Launch one search_and_summarize call per query and run them concurrently
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]
    # An empty query list yields no results instead of an error
    results_list = await asyncio.gather(*tasks) if tasks else []

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results

# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

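    # Map 1-based source numbers back to results, skipping out-of-range values
    # (a 0 would otherwise become index -1 and silently pick the last result)
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]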

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

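def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while "<think>" in answer:
        start = answer.find("<think>")
        # Look for the closing tag only after the opening one, so a stray
        # "</think>" earlier in the text cannot cause an endless loop
        end = answer.find("</think>", start)
        if end == -1:
            break  # unterminated <think> block; leave the rest as-is
        answer = answer[:start] + answer[end + len("</think>") :]
    return answer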

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: asingle_shot_llm_call forwards this value as max_tokens (as Together AI requires);
        # other providers may need a different parameter
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
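
`filter_results` asks the JSON model for a list of 1-based source numbers, truncates that list to `max_sources` (where `-1` means no limit), and maps the numbers back onto the result list. The sketch below isolates that mapping with plain lists so it can be checked without an LLM. The helper name `select_sources` is hypothetical. It also rejects numbers below 1, since a `0` would otherwise become index `-1` and silently select the last result.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def select_sources(results: Sequence[T], sources: list[int], max_sources: int) -> list[T]:
    """Map 1-based source numbers chosen by an LLM back onto a result list.

    Hypothetical helper for illustration only.
    """
    # -1 means "keep every source the LLM selected"
    if max_sources != -1:
        sources = sources[:max_sources]
    # Convert 1-based numbers to 0-based indices, skipping out-of-range values
    return [results[i - 1] for i in sources if 1 <= i <= len(results)]


results = ["a", "b", "c", "d"]
print(select_sources(results, [2, 4, 9, 0, 1], max_sources=-1))  # ['b', 'd', 'a']
print(select_sources(results, [2, 4, 9, 0, 1], max_sources=2))   # ['b', 'd']
```

Truncation happens before the bounds check, as in the tutorial, so out-of-range numbers still count toward `max_sources`.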

All LLM calls go through LiteLLM. The async helper the tasks use, `asingle_shot_llm_call`, streams the response and is decorated with `flyte.trace` for observability:

```
from typing import Any, AsyncIterator, Optional

from litellm import acompletion, completion

import flyte

# {{docs-fragment asingle_shot_llm_call}}
@flyte.trace
async def asingle_shot_llm_call(
    model: str,
    system_prompt: str,
    message: str,
    response_format: Optional[dict[str, str | dict[str, Any]]] = None,
    max_completion_tokens: int | None = None,
) -> AsyncIterator[str]:
    stream = await acompletion(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        temperature=0.0,
        response_format=response_format,
        # NOTE: max_tokens is deprecated in the OpenAI API in favor of max_completion_tokens; prefer the latter where supported
        # NOTE: max_completion_tokens is not currently supported by Together AI, so we pass max_tokens instead
        max_tokens=max_completion_tokens,
        timeout=600,
        stream=True,
    )
    async for chunk in stream:
        content = chunk.choices[0].delta.get("content", "")
        if content:
            yield content

# {{/docs-fragment asingle_shot_llm_call}}

def single_shot_llm_call(
    model: str,
    system_prompt: str,
    message: str,
    response_format: Optional[dict[str, str | dict[str, Any]]] = None,
    max_completion_tokens: int | None = None,
) -> str:
    response = completion(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        temperature=0.0,
        response_format=response_format,
        # NOTE: max_tokens is deprecated in the OpenAI API in favor of max_completion_tokens; prefer the latter where supported
        # NOTE: max_completion_tokens is not currently supported by Together AI, so we pass max_tokens instead
        max_tokens=max_completion_tokens,
        timeout=600,
    )
    return response.choices[0].message["content"]  # type: ignore
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/libs/utils/llms.py*

> [!NOTE]
> We use `flyte.trace` to track intermediate steps within a task, like LLM calls or specific function executions. This lightweight decorator adds observability with minimal overhead and is especially useful for inspecting reasoning chains during task execution.
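
To make the idea concrete, here is a minimal, stdlib-only sketch of what tracing a streaming step involves: record the call's inputs, pass each chunk through to the caller, and store the combined output once the stream is exhausted. This is not how `flyte.trace` is implemented. The `traced` decorator and the `TRACE_LOG` list are hypothetical stand-ins for a tracing backend, and `fake_llm` replaces a real LiteLLM call.

```python
import asyncio
import functools
from typing import Any, AsyncIterator, Callable

# Hypothetical in-memory store standing in for a real tracing backend
TRACE_LOG: list[dict[str, Any]] = []


def traced(fn: Callable[..., AsyncIterator[str]]) -> Callable[..., AsyncIterator[str]]:
    """Record inputs and the full streamed output of an async generator function."""

    @functools.wraps(fn)
    async def wrapper(*args: Any, **kwargs: Any) -> AsyncIterator[str]:
        chunks: list[str] = []
        async for chunk in fn(*args, **kwargs):
            chunks.append(chunk)
            yield chunk  # pass each chunk through to the caller unchanged
        # Only reached once the caller has consumed the whole stream
        TRACE_LOG.append({"name": fn.__name__, "kwargs": kwargs, "output": "".join(chunks)})

    return wrapper


@traced
async def fake_llm(message: str) -> AsyncIterator[str]:
    # Stand-in for a streaming LLM call; yields the reply word by word
    for word in f"echo: {message}".split(" "):
        yield word + " "


async def main() -> str:
    # Accumulate streamed chunks, the same way the tasks build up `plan` or `answer`
    answer = ""
    async for chunk in fake_llm(message="hello world"):
        answer += chunk
    return answer.strip()


print(asyncio.run(main()))  # echo: hello world
print(TRACE_LOG)
```

Because this sketch writes the log entry only after the stream is exhausted, a caller that stops iterating early would leave no record; a production tracer has to handle that case.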

## Search and summarize

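We submit each research query to Tavily and summarize the results using an LLM. Within each `search_and_summarize` task, the summarization calls for the individual search hits run concurrently through `asyncio.gather`. One level up, `search_all_queries` gathers the `search_and_summarize` calls themselves; because each call is a separate Flyte task, the searches for different queries can be distributed across separate compute resources.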
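
The `return_exceptions=True` pattern used in `search_and_summarize` is worth isolating. `asyncio.gather` returns results in the same order as its inputs, with any raised exception returned in place of that coroutine's result. Pairing each outcome with its input before discarding failures keeps every summary attached to the right search hit. In this sketch, `summarize` is a hypothetical stand-in for the LLM summarization call that fails on empty input.

```python
import asyncio


async def summarize(text: str) -> str:
    # Hypothetical stand-in for an LLM summarization call
    if not text:
        raise ValueError("nothing to summarize")
    await asyncio.sleep(0)  # yield to the event loop like a real network call would
    return text.upper()


async def summarize_all(docs: list[str]) -> list[tuple[str, str]]:
    # Run all summaries concurrently; failures come back as exception objects
    outcomes = await asyncio.gather(*(summarize(d) for d in docs), return_exceptions=True)
    # Pair each outcome with its input first, then drop the failures,
    # so a failed call never shifts later summaries onto the wrong document
    return [
        (doc, out)
        for doc, out in zip(docs, outcomes)
        if not isinstance(out, BaseException)
    ]


print(asyncio.run(summarize_all(["alpha", "", "gamma"])))
# [('alpha', 'ALPHA'), ('gamma', 'GAMMA')]
```

Filtering the outcomes before zipping would instead pair `'GAMMA'` with the empty second document.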

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}

# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}

async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")

    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"<Raw Content>{raw_content}</Raw Content>\n\n<Research Topic>{query}</Research Topic>",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result

# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""

    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)

    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue

            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True so one failed summary doesn't fail the whole task
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    # Pair each summary with its search result before skipping failures;
    # filtering first would shift later summaries onto the wrong result
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        if isinstance(summarized_content, Exception):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}

@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    results_list = []

    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]

    if tasks:
        # Each search_and_summarize call is its own task, so these can run in parallel
        res_list = await asyncio.gather(*tasks)
        results_list.extend(res_list)

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results

# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

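    # Map 1-based source numbers back to results, skipping out-of-range values
    # (a 0 would otherwise become index -1 and silently pick the last result)
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]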

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

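def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while "<think>" in answer:
        start = answer.find("<think>")
        # Look for the closing tag only after the opening one, so a stray
        # "</think>" earlier in the text cannot cause an endless loop
        end = answer.find("</think>", start)
        if end == -1:
            break  # unterminated <think> block; leave the rest as-is
        answer = answer[:start] + answer[end + len("</think>") :]
    return answer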

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: asingle_shot_llm_call forwards this value as max_tokens (as Together AI requires);
        # other providers may need a different parameter
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
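
The listings combine result sets with `+` and remove repeats with `results.dedup()`, but the `libs/utils/data_types.py` module that defines `DeepResearchResults` is not shown here. The following is a hypothetical, dataclass-based approximation, assuming deduplication keys on each result's `link` and keeps the first occurrence; the real class may behave differently. The field names are taken from the `DeepResearchResult(...)` call in `search_and_summarize`.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Result:
    # Hypothetical approximation of DeepResearchResult
    title: str
    link: str
    content: str
    raw_content: Optional[str] = None
    filtered_raw_content: Optional[str] = None


@dataclass
class Results:
    # Hypothetical approximation of DeepResearchResults
    results: list[Result] = field(default_factory=list)

    def __add__(self, other: "Results") -> "Results":
        # Concatenate without mutating either operand
        return Results(results=self.results + other.results)

    def dedup(self) -> "Results":
        # Assumption: results sharing a link are duplicates; the first one wins
        seen: set[str] = set()
        unique: list[Result] = []
        for r in self.results:
            if r.link not in seen:
                seen.add(r.link)
                unique.append(r)
        return Results(results=unique)


first = Results([Result("A", "https://a.example", "..."), Result("B", "https://b.example", "...")])
second = Results([Result("B again", "https://b.example", "..."), Result("C", "https://c.example", "...")])
combined = (first + second).dedup()
print([r.title for r in combined.results])  # ['A', 'B', 'C']
```

Returning new objects from `+` and `dedup()` matches how the tutorial uses them, always reassigning (`results = results + new_results`) rather than mutating in place.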

## Evaluate research completeness

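Next, we assess whether the gathered research is sufficient. As with query generation, the task makes two LLM calls: the planning model evaluates how complete the current results are, and the JSON model parses that evaluation into a list of additional queries, which is empty when no further research is needed.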
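
Stripped of Flyte and the LLM calls, the control flow that `research_topic` builds around this task is a bounded loop: evaluate, stop if no new queries come back, otherwise cap the new queries at `max_queries`, search, and merge. The sketch below is a simplified, synchronous rendering; `evaluate` and `search` are hypothetical stubs standing in for `evaluate_research_completeness` and `search_all_queries`.

```python
from typing import Callable


def iterative_research(
    initial_queries: list[str],
    evaluate: Callable[[list[str], list[str]], list[str]],
    search: Callable[[list[str]], list[str]],
    budget: int,
    max_queries: int,
) -> tuple[list[str], list[str]]:
    """Run at most `budget` evaluate-and-search rounds after the initial search."""
    all_queries = list(initial_queries)
    results = search(initial_queries)
    for _ in range(budget):
        # Ask the evaluator for follow-up queries and drop empty strings
        additional = [q for q in evaluate(results, all_queries) if q]
        if not additional:
            break  # research judged complete
        additional = additional[:max_queries]
        results = results + search(additional)
        all_queries.extend(additional)
    return results, all_queries


# Hypothetical stubs: each query yields one "result", and the evaluator
# asks for one follow-up until three results have been gathered
def search(queries: list[str]) -> list[str]:
    return [f"result for {q}" for q in queries]


def evaluate(results: list[str], queries: list[str]) -> list[str]:
    return [f"follow-up {len(results)}", ""] if len(results) < 3 else []


results, queries = iterative_research(["topic"], evaluate, search, budget=5, max_queries=2)
print(queries)  # ['topic', 'follow-up 1', 'follow-up 2']
```

Here `budget` bounds the number of follow-up rounds, not the total number of searches: each round can still issue up to `max_queries` queries.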

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}

# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}

async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")

    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"<Raw Content>{raw_content}</Raw Content>\n\n<Research Topic>{query}</Research Topic>",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result

# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""

    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)

    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue

            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        # Skip failed summaries here instead of filtering them out beforehand,
        # so each summary stays paired with the search result it came from
        if isinstance(summarized_content, Exception):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}

@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]

    # Guard against an empty query list so results_list is always defined
    results_list = await asyncio.gather(*tasks) if tasks else []

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results

# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

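    # Filter the results based on the 1-based source list, skipping any index
    # outside the valid range (0 would otherwise wrap around to the last item)
    filtered_results = [
        results.results[i - 1] for i in sources if 0 < i <= len(results.results)
    ]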

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

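def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while "<think>" in answer:
        start = answer.find("<think>")
        # Look for the closing tag only after the opening one, so a stray
        # "</think>" earlier in the text cannot cause an endless loop
        close = answer.find("</think>", start)
        if close == -1:
            break
        answer = answer[:start] + answer[close + len("</think>") :]
    return answer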

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*

## Filter results

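In this step, the planning model judges the relevance of the deduplicated search results and ranks them, and the JSON model parses its response into a list of 1-based source indices. The task keeps at most `max_sources` of those sources (`-1` disables the cap) and returns the corresponding results for the final synthesis.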
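
The index handling at the end of `filter_results` can be sketched on its own. This is a minimal illustration that assumes plain strings stand in for `DeepResearchResult` objects and uses a hypothetical helper name, `select_sources`. The check that rejects indices below 1 is a defensive choice, not something the `SourceList` schema is known to guarantee.

```python
def select_sources(results: list[str], sources: list[int], max_sources: int) -> list[str]:
    """Keep the results the filter LLM selected, in its ranked order.

    `sources` holds 1-based indices into `results`, mirroring how
    filter_results maps the parsed source list back onto the search results.
    """
    if max_sources != -1:
        sources = sources[:max_sources]  # -1 means "no cap"
    # Skip indices outside 1..len(results) instead of raising IndexError
    # or letting 0 wrap around to the last element.
    return [results[i - 1] for i in sources if 0 < i <= len(results)]


results = ["doc A", "doc B", "doc C"]
print(select_sources(results, [3, 1, 7, 0], max_sources=-1))  # ['doc C', 'doc A']
print(select_sources(results, [3, 1, 2], max_sources=2))  # ['doc C', 'doc A']
```

The cap is applied before the range check, so an out-of-range index still counts toward `max_sources`. This matches the order of operations in the tutorial code.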

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}

# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}

async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")

    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"<Raw Content>{raw_content}</Raw Content>\n\n<Research Topic>{query}</Research Topic>",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result

# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""

    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)

    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue

            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        # Skip failed summaries here instead of filtering them out beforehand,
        # so each summary stays paired with the search result it came from
        if isinstance(summarized_content, Exception):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}

@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]

    # Guard against an empty query list so results_list is always defined
    results_list = await asyncio.gather(*tasks) if tasks else []

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results

# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

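    # Filter the results based on the 1-based source list, skipping any index
    # outside the valid range (0 would otherwise wrap around to the last item)
    filtered_results = [
        results.results[i - 1] for i in sources if 0 < i <= len(results.results)
    ]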

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

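def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while "<think>" in answer:
        start = answer.find("<think>")
        # Look for the closing tag only after the opening one, so a stray
        # "</think>" earlier in the text cannot cause an endless loop
        close = answer.find("</think>", start)
        if close == -1:
            break
        answer = answer[:start] + answer[close + len("</think>") :]
    return answer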

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
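Before the report is written, `filter_results` maps the filter model's source IDs back onto the deduplicated results. The listing treats the IDs as 1-based (it indexes `results.results[i - 1]`) and uses `max_sources=-1` to mean "no cap". The sketch below isolates that mapping, with plain strings standing in for `DeepResearchResult` objects. `select_sources` is a hypothetical helper, not part of the tutorial. It accepts only IDs from 1 to `len(results)`, because a bare upper-bound check lets an ID of 0 through, and Python's negative indexing would then return the last result.

```python
def select_sources(
    results: list[str], source_ids: list[int], max_sources: int = -1
) -> list[str]:
    """Map 1-based source IDs back to results (hypothetical helper).

    Plain strings stand in for the tutorial's DeepResearchResult objects.
    max_sources == -1 means "no cap", mirroring the tutorial's convention.
    """
    if max_sources != -1:
        source_ids = source_ids[:max_sources]
    # Keep only IDs in 1..len(results); 0 or negative IDs would otherwise
    # silently index from the end of the list
    return [results[i - 1] for i in source_ids if 1 <= i <= len(results)]


print(select_sources(["doc-a", "doc-b", "doc-c"], [3, 1, 0, 7], max_sources=3))
# ['doc-c', 'doc-a']
```

With the IDs `[3, 1, 0, 7]` and a cap of 3, only the first three IDs are considered, and the 0 is dropped instead of wrapping around to `doc-c`.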

## Generate the final answer

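Finally, `generate_research_answer` synthesizes the filtered results into a detailed research report; this is the output returned to the user. The task streams the report from `answer_model`, optionally strips `<think>` reasoning blocks, and unwraps the reply if the model enclosed it in a markdown code block.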

```
def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
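    while "<think>" in answer:
        start = answer.find("<think>")
        # Look for the closing tag after the opening one: matching an earlier,
        # stray "</think>" would re-insert text on every pass and never terminate
        end = answer.find("</think>", start)
        if end == -1:
            break
        answer = answer[:start] + answer[end + len("</think>") :]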
    return answer

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
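The post-processing in `generate_research_answer` is plain string handling, so it can be tested without an LLM call. The following is a minimal standalone sketch of those two steps; the names `remove_thinking_tags`, `strip_code_fence`, and `clean_answer` are hypothetical. The closing `</think>` tag is searched for only after the opening tag, so a stray closing tag earlier in the reply is left alone rather than causing the loop to rebuild the string forever.

```python
def remove_thinking_tags(answer: str) -> str:
    """Drop each <think>...</think> span; leave unmatched tags untouched."""
    while "<think>" in answer:
        start = answer.find("<think>")
        # Search for the closing tag only after the opening tag
        end = answer.find("</think>", start)
        if end == -1:
            break
        answer = answer[:start] + answer[end + len("</think>"):]
    return answer


def strip_code_fence(answer: str) -> str:
    """Unwrap a reply that the model enclosed in a markdown code block."""
    if answer.lstrip().startswith("```"):
        # Drop the opening fence line, which may carry a language tag
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            answer = answer[first_linebreak + 1:]
        # Drop the closing fence, if present
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()
    return answer.strip()


def clean_answer(raw: str, remove_thinking: bool = True) -> str:
    if remove_thinking:
        raw = remove_thinking_tags(raw)
    return strip_code_fence(raw)


raw = "<think>plan the report</think>\n```markdown\n# Report\nBody\n```\n"
print(repr(clean_answer(raw)))  # prints '# Report\nBody'
```

Keeping the cleanup in small pure functions like these makes it easy to cover edge cases (nested or unmatched tags, fences with language tags) with ordinary unit tests.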

## Orchestration

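The `research_topic` task orchestrates the entire deep research workflow. It generates the initial research queries (always including the topic itself) and runs a first round of search and summarization. Then, for at most `budget` iterations, it evaluates whether the results are complete and searches any follow-up queries the evaluation suggests, stopping early once no new queries are returned. Finally, it deduplicates and filters the results and generates the final report. The task is declared with a `flyte.RetryStrategy`, so a failed run is retried with backoff.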
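Stripped of Flyte and the model calls, the control flow of `research_topic` is a bounded loop. The sketch below reproduces it with hypothetical stub coroutines (`plan_queries`, `search`, `evaluate`) that return canned data in place of the real tasks. It assumes that `results.dedup()` drops repeated results while keeping the first occurrence, and models that with `dict.fromkeys`; the real `DeepResearchResults.dedup` may compare results differently.

```python
import asyncio


# Hypothetical stand-ins for the tutorial's tasks; they return canned data so
# the control flow can run without Flyte, Tavily, or an LLM.
async def plan_queries(topic: str) -> list[str]:
    return [f"{topic} overview", f"{topic} tools", f"{topic} risks"]


async def search(queries: list[str]) -> list[str]:
    return [f"result for {q}" for q in queries]


async def evaluate(results: list[str], round_index: int) -> list[str]:
    # Request one follow-up query in the first round, then report completion
    return ["follow-up query", ""] if round_index == 0 else []


async def research(topic: str, budget: int = 2, max_queries: int = 3) -> list[str]:
    # Step 1: the topic itself plus at most max_queries - 1 planned queries
    planned = await plan_queries(topic)
    queries = [topic, *planned[: max_queries - 1]]
    # Step 2: initial search
    results = await search(queries)
    # Step 3: at most `budget` rounds of evaluate-then-search
    for round_index in range(budget):
        feedback = await evaluate(results, round_index)
        extra = [q for q in feedback if q]  # drop empty query strings
        if not extra:
            break  # the evaluator judged the results sufficient
        new_results = await search(extra[:max_queries])
        results = results + new_results
    # Step 4: order-preserving de-duplication before filtering and answering
    return list(dict.fromkeys(results))


print(asyncio.run(research("agent orchestration")))
```

Here the second evaluation returns no queries, so the loop exits with four results (three initial searches plus one follow-up). With `budget=0`, the evaluation step is skipped entirely and only the initial results are kept.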

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "pydantic==2.11.5",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
#    "together==1.5.24",
#    "markdown==3.8.2",
#    "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///

# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path

import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096

logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}

# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]

    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)

    SEARCH_PROMPT = prompts["plan_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}

async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")

    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"<Raw Content>{raw_content}</Raw Content>\n\n<Research Topic>{query}</Research Topic>",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result

# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""

    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")

    response = await atavily_search_results(query)

    logging.info("Tavily Search Called.")

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]

    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue

            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)

        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )

    # Pair each summary with its source result *before* dropping failures, so a
    # failed summarization cannot shift later summaries onto the wrong result
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        if isinstance(summarized_content, BaseException):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}

@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    results_list = []

    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]

    # Run all searches concurrently; with no queries, fall through to an empty
    # result instead of referencing an unassigned variable
    if tasks:
        results_list.extend(await asyncio.gather(*tasks))

    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results

    return combined_results

# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """

    # Format the search results for the LLM
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)

    EVALUATION_PROMPT = prompts["evaluation_prompt"]

    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Search Queries Used>{queries}</Search Queries Used>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)

    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}

# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""

    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]

    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"<Research Topic>{topic}</Research Topic>\n\n"
            f"<Current Search Results>{formatted_results}</Current Search Results>"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)

    logging.info(f"Filter response: {filter_response}")

    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]

    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk

    sources = json.loads(response_json)["sources"]

    logging.info(f"Filtered sources: {sources}")

    if max_sources != -1:
        sources = sources[:max_sources]

    # Filter the results based on the source list (source numbers are 1-based)
    filtered_results = [
        results.results[i - 1] for i in sources if 0 < i <= len(results.results)
    ]

    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}

def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while "<think>" in answer:
        start = answer.find("<think>")
        # Look for the closing tag only after the opening one
        end = answer.find("</think>", start)
        if end == -1:
            break  # unclosed tag: leave the rest of the answer as-is
        answer = answer[:start] + answer[end + len("</think>") :]
    return answer

# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """

    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]

    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk

    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"

    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)

    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]

        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()

    return answer.strip()
# {{/docs-fragment generate_research_answer}}

# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")

    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"

    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")

    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )

            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break

            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")

            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )

            results = results + new_results
            all_queries.extend(additional_queries)

    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )

    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )

    return answer
# {{/docs-fragment research_topic}}

# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)

    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )

    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")

    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )

    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()

    return html_content
# {{/docs-fragment main}}

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py*
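`filter_results` asks the planning model which sources to keep, has the JSON model return them as a list of 1-based source numbers, truncates that list to `max_sources` (where `-1` means no limit), and maps the numbers back onto the results. The same mapping on plain lists, as a hypothetical standalone helper (`select_sources` is not part of the tutorial code), looks like this; it accepts only numbers from 1 to the number of results, because 0 or a negative number would otherwise index from the end of a Python list:

```python
def select_sources(results: list[str], sources: list[int], max_sources: int) -> list[str]:
    """Map 1-based source numbers chosen by the LLM back onto the result list."""
    if max_sources != -1:
        sources = sources[:max_sources]  # -1 means "keep every selected source"
    # Skip numbers outside 1..len(results) instead of letting 0 wrap to the last item
    return [results[i - 1] for i in sources if 0 < i <= len(results)]


print(select_sources(["a", "b", "c"], [3, 1, 0, 7], max_sources=-1))  # ['c', 'a']
print(select_sources(["a", "b", "c"], [3, 1, 2], max_sources=2))  # ['c', 'a']
```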

The `main` task wraps this entire pipeline and adds HTML report generation as the final step.
It also serves as the workflow's entry point, so we can pass in every configuration parameter there, including which LLM to use at each stage.
This lets us mix and match models for planning, summarization, and final synthesis to balance cost against quality.
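The control flow that `research_topic` implements (seed the queries with the topic, search, let an evaluator request follow-up queries within a budget, merge, then deduplicate) can be exercised without any LLM or search calls. The sketch below is framework-free and hypothetical: `Results` stands in for `DeepResearchResults` (whose real definition in `libs/utils/data_types.py` is not shown here), deduplicating by link is an assumption, and `search` and `evaluate` are stubs for the Tavily-backed search and the evaluation task.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class Result:
    title: str
    link: str


@dataclass
class Results:
    """Hypothetical stand-in for DeepResearchResults."""

    results: list[Result] = field(default_factory=list)

    def __add__(self, other: "Results") -> "Results":
        # Concatenate without mutating either operand
        return Results(self.results + other.results)

    def dedup(self) -> "Results":
        # Assumption: results sharing a link are duplicates; keep the first one
        seen: set[str] = set()
        unique: list[Result] = []
        for r in self.results:
            if r.link not in seen:
                seen.add(r.link)
                unique.append(r)
        return Results(unique)


async def search(query: str) -> Results:
    # Stub for search_and_summarize: one fake hit per query
    return Results([Result(title=query, link=f"https://example.com/{query}")])


async def search_all(queries: list[str]) -> Results:
    # Run the searches concurrently, then merge them, like search_all_queries
    combined = Results()
    for partial in await asyncio.gather(*(search(q) for q in queries)):
        combined = combined + partial
    return combined


async def evaluate(all_queries: list[str]) -> list[str]:
    # Stub for evaluate_research_completeness: request one follow-up, then stop
    return [] if "follow-up" in all_queries else ["follow-up", ""]


async def research(topic: str, planned: list[str], budget: int, max_queries: int) -> Results:
    queries = [topic, *planned[: max_queries - 1]]  # the topic itself is always searched
    all_queries = list(queries)
    results = await search_all(queries)
    for _ in range(budget):
        proposed = await evaluate(all_queries)
        extra = [q for q in proposed if q][:max_queries]  # drop empty strings, cap count
        if not extra:
            break  # the evaluator considers the research complete
        results = results + await search_all(extra)
        all_queries.extend(extra)
    return results.dedup()


final = asyncio.run(research("agents", ["agents", "orchestration"], budget=3, max_queries=3))
print([r.title for r in final.results])  # ['agents', 'orchestration', 'follow-up']
```

The budget bounds the number of evaluate-and-search rounds, while an empty list from the evaluator ends the loop early; deduplication runs once at the end because follow-up searches often return pages already seen.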

## Run the deep research agent

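First, create the secrets referenced by the task environment (`together_api_key` and `tavily_api_key`):

```
flyte create secret together_api_key <your-together-api-key>
flyte create secret tavily_api_key <your-tavily-api-key>
```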

Run the agent:

```
uv run agent.py
```

If you want to test it locally first, run the following commands:

```
brew install pandoc
brew install basictex # restart your terminal after install

export TOGETHER_API_KEY=<your-together-api-key>
export TAVILY_API_KEY=<your-tavily-api-key>

uv run agent.py
```
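For local runs, the two keys come from your shell environment; remotely, `flyte.Secret(..., as_env_var=...)` in the task environment exposes the stored secrets under the same variable names. A small preflight check (a hypothetical helper, not part of the tutorial code) can catch a missing key before the agent starts making API calls:

```python
import os

# Variable names taken from the tutorial; the helper itself is hypothetical
REQUIRED_ENV_VARS = ("TOGETHER_API_KEY", "TAVILY_API_KEY")


def missing_env_vars(names=REQUIRED_ENV_VARS, environ=os.environ) -> list[str]:
    """Return the required variables that are unset or empty."""
    return [name for name in names if not environ.get(name)]


if __name__ == "__main__":
    missing = missing_env_vars()
    print("Missing API keys:", ", ".join(missing) if missing else "none")
```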

## Evaluate with Weights & Biases Weave

We use W&B Weave to evaluate the full agent pipeline. The evaluation runs as a Flyte pipeline and uses an LLM-as-a-judge scorer to decide whether each response from the agent contains the reference answer.
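The `load_questions` task below returns rows with `question` and `answer` keys, and the scorer's parameters are named `answer`, `output`, and `question`. As we understand Weave's scorer convention, dataset columns are passed to a scorer by parameter name and the model's result arrives as `output`. The sketch below imitates that wiring in plain Python, with a stub model and a substring check in place of the LLM judge; `run_eval`, `stub_model`, and `substring_scorer` are hypothetical, and this is not how Weave is implemented internally.

```python
import inspect
from collections.abc import Callable


def run_eval(rows: list[dict], model: Callable[[str], str], scorer: Callable[..., bool]) -> float:
    """Score each row and return accuracy (a plain-Python imitation, not Weave itself)."""
    params = inspect.signature(scorer).parameters
    scores = []
    for row in rows:
        output = model(row["question"])
        # Hand the scorer the model output as `output`, plus any row columns it names
        kwargs = {name: row[name] for name in params if name in row and name != "output"}
        scores.append(bool(scorer(output=output, **kwargs)))
    return sum(scores) / len(scores) if scores else 0.0


def stub_model(question: str) -> str:
    # Stand-in for research_topic
    return "Paris is the capital of France." if "France" in question else "I don't know."


def substring_scorer(answer: str, output: str, question: str) -> bool:
    # Stand-in for llm_as_a_judge_scoring: correct if the reference answer appears in the output
    return answer.lower() in output.lower()


rows = [
    {"question": "What is the capital of France?", "answer": "Paris", "dataset": "demo"},
    {"question": "What is the capital of Peru?", "answer": "Lima", "dataset": "demo"},
]
print(run_eval(rows, stub_model, substring_scorer))  # 0.5
```

Extra columns such as `dataset` are simply ignored because the scorer does not name them.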

```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "weave==0.51.51",
#    "datasets==3.6.0",
#    "huggingface-hub==0.32.6",
#    "litellm==1.72.2",
#    "tavily-python==0.7.5",
# ]
# ///

import os

import weave
from agent import research_topic
from datasets import load_dataset
from huggingface_hub import login
from libs.utils.log import AgentLogger
from litellm import completion

import flyte

logging = AgentLogger()

weave.init(project_name="deep-researcher")

env = flyte.TaskEnvironment(name="deep-researcher-eval")

@weave.op
def llm_as_a_judge_scoring(answer: str, output: str, question: str) -> bool:
    prompt = f"""
    Given the following question and answer, evaluate the answer against the correct answer:

    <question>
    {question}
    </question>

    <agent_answer>
    {output}
    </agent_answer>

    <correct_answer>
    {answer}
    </correct_answer>

    Note that the agent answer might be a long text containing a lot of information or it might be a short answer.

    You should read the entire text and think if the agent answers the question somewhere
    in the text. You should try to be flexible with the answer but careful.

    For example, answering with names instead of name and surname is fine.

    The important thing is that the answer of the agent either contains the correct answer or is equal to
    the correct answer.

    If it does, return

    <reasoning>
    The agent answer is correct because I can read that ....
    </reasoning>

    <answer>
    1
    </answer>

    Otherwise, return

    <reasoning>
    The agent answer is incorrect because there is ...
    </reasoning>

    <answer>
    0
    </answer>

    """

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant that returns either 0 or 1.",
        },
        {"role": "user", "content": prompt},
    ]
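    judge_output = (
        completion(
            model="together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
            messages=messages,
            max_tokens=1000,
            temperature=0.0,
        )
        .choices[0]  # type: ignore
        .message["content"]  # type: ignore
    )

    # The judge wraps its verdict in <answer> tags; treat a missing or malformed verdict as incorrect
    if not judge_output or "<answer>" not in judge_output or "</answer>" not in judge_output:
        return False
    verdict = judge_output.split("<answer>")[-1].split("</answer>")[0].strip()
    return verdict == "1"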

def authenticate_huggingface():
    """Authenticate with Hugging Face Hub using token from environment variable."""
    token = os.getenv("HUGGINGFACE_TOKEN")
    if not token:
        raise ValueError(
            "HUGGINGFACE_TOKEN environment variable not set. "
            "Please set it with your token from https://huggingface.co/settings/tokens"
        )

    try:
        login(token=token)
        print("Successfully authenticated with Hugging Face Hub")
    except Exception as e:
        raise RuntimeError(f"Failed to authenticate with Hugging Face Hub: {e!s}") from e

@env.task
async def load_questions(
    dataset_names: list[str] | None = None,
) -> list[dict[str, str]]:
    """
    Load questions from the specified Hugging Face dataset configurations.

    Args:
        dataset_names: List of dataset configurations to load
                      Options:
                          "smolagents:simpleqa",
                          "hotpotqa",
                          "simpleqa",
                          "together-search-bench"
                      If None, only "smolagents:simpleqa" is loaded

    Returns:
        List of question-answer pairs
    """
    if dataset_names is None:
        dataset_names = ["smolagents:simpleqa"]

    all_questions = []

    # Authenticate with Hugging Face Hub (once and for all)
    authenticate_huggingface()

    for dataset_name in dataset_names:
        print(f"Loading dataset: {dataset_name}")

        try:
            if dataset_name == "together-search-bench":
                # Load Together-Search-Bench dataset
                dataset_path = "togethercomputer/together-search-bench"
                ds = load_dataset(dataset_path)
                if "test" in ds:
                    split_data = ds["test"]
                else:
                    print(f"No 'test' split found in dataset at {dataset_path}")
                    continue

                for i in range(len(split_data)):
                    item = split_data[i]
                    question_data = {
                        "question": item["question"],
                        "answer": item["answer"],
                        "dataset": item.get("dataset", "together-search-bench"),
                    }
                    all_questions.append(question_data)

                print(f"Loaded {len(split_data)} questions from together-search-bench dataset")
                continue

            elif dataset_name == "hotpotqa":
                # Load HotpotQA dataset (using distractor version for validation)
                ds = load_dataset("hotpotqa/hotpot_qa", "distractor", trust_remote_code=True)
                split_name = "validation"
            elif dataset_name == "simpleqa":
                ds = load_dataset("basicv8vc/SimpleQA")
                split_name = "test"
            else:
                # Strip "smolagents:" prefix when loading the dataset
                actual_dataset = dataset_name.split(":")[-1]
                ds = load_dataset("smolagents/benchmark-v1", actual_dataset)
                split_name = "test"

        except Exception as e:
            print(f"Failed to load dataset {dataset_name}: {e!s}")
            continue  # Skip this dataset if it fails to load

        print(f"Dataset structure for {dataset_name}: {ds}")
        print(f"Available splits: {list(ds)}")

        split_data = ds[split_name]  # type: ignore

        for i in range(len(split_data)):
            item = split_data[i]

            if dataset_name == "hotpotqa":
                # Keep only "hard" questions to reduce the number of HotpotQA questions
                if item["level"] != "hard":
                    continue

                question_data = {
                    "question": item["question"],
                    "answer": item["answer"],
                    "dataset": dataset_name,
                }
            elif dataset_name == "simpleqa":
                # Handle SimpleQA dataset format
                question_data = {
                    "question": item["problem"],
                    "answer": item["answer"],
                    "dataset": dataset_name,
                }
            else:
                question_data = {
                    "question": item["question"],
                    "answer": item["true_answer"],
                    "dataset": dataset_name,
                }

            all_questions.append(question_data)

    print(f"Loaded {len(all_questions)} questions in total")
    return all_questions

@weave.op
async def predict(question: str):
    return await research_topic(topic=str(question))

@env.task
async def main(datasets: list[str] = ["together-search-bench"], limit: int | None = 1):
    questions = await load_questions(datasets)

    if limit is not None:
        questions = questions[:limit]
        print(f"Limited to {len(questions)} question(s)")

    evaluation = weave.Evaluation(dataset=questions, scorers=[llm_as_a_judge_scoring])
    await evaluation.evaluate(predict)

if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(raw_data_path="data").run(main)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/weave_evals.py*
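Each branch in `load_questions` maps a dataset-specific row layout onto the same `{question, answer, dataset}` shape. As a sketch of that normalization step (the `FIELD_MAP` table and `normalize` helper are our own, not part of the example repo), the field names can live in a lookup table; the HotpotQA difficulty filter and the Together-Search-Bench special case are left out for brevity:

```python
from typing import Any

# Hypothetical lookup table (not in the example repo): dataset -> (question field, answer field).
# Field names mirror the branches in load_questions.
FIELD_MAP: dict[str, tuple[str, str]] = {
    "hotpotqa": ("question", "answer"),
    "simpleqa": ("problem", "answer"),
}
# "smolagents:*" configurations fall back to this pair.
DEFAULT_FIELDS = ("question", "true_answer")


def normalize(item: dict[str, Any], dataset_name: str) -> dict[str, str]:
    """Map one raw dataset row to the common {question, answer, dataset} schema."""
    q_key, a_key = FIELD_MAP.get(dataset_name, DEFAULT_FIELDS)
    return {
        "question": item[q_key],
        "answer": item[a_key],
        "dataset": dataset_name,
    }


rows = [{"problem": "What is 2 + 2?", "answer": "4"}]
print([normalize(row, "simpleqa") for row in rows])
```

With this layout, supporting another dataset means adding one table entry rather than another `elif` branch, at the cost of hiding per-dataset quirks such as filters.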

You can run this pipeline locally as follows:

```
export HUGGINGFACE_TOKEN=<> # https://huggingface.co/settings/tokens
export WANDB_API_KEY=<> # https://wandb.ai/settings

uv run weave_evals.py
```

The script runs every task in the pipeline and logs the evaluation results to Weights & Biases via Weave.
You can also evaluate individual tasks, but this script evaluates the deep research workflow end to end.

![Weave evaluations](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/deep-research/weave_evals.png)

=== PAGE: https://www.union.ai/docs/v2/union/tutorials/hpo ===

# Hyperparameter optimization

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/ml/optimizer.py).

Hyperparameter optimization (HPO) is a critical step in the machine learning (ML) lifecycle. Hyperparameters are the knobs and dials of a model: values such as learning rates, tree depths, or dropout rates that significantly affect performance but are not learned from the data during training. Instead, we must select them manually or optimize them through guided search.

Model developers often enjoy the flexibility of choosing from a wide variety of model types, whether gradient boosted machines (GBMs), generalized linear models (GLMs), deep learning architectures, or dozens of others. A common challenge across all these options is the need to systematically explore model performance across hyperparameter configurations tailored to the specific dataset and task.

Thankfully, this exploration can be automated. Frameworks like [Optuna](https://optuna.org/), [Hyperopt](https://hyperopt.github.io/hyperopt/), and [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) use adaptive search algorithms, such as the Tree-structured Parzen Estimator (TPE), that choose each new configuration based on the results so far, and typically find good configurations in far fewer trials than an exhaustive grid search. HPO can be executed in two ways:

- **Serial HPO** runs one trial at a time, which is easy to set up but can be painfully slow.
- **Parallel HPO** distributes trials across multiple processes. It typically follows a pattern with two parameters: **_N_**, the total number of trials to run, and **_C_**, the maximum number of trials that can run concurrently. Trials are executed asynchronously, and new ones are scheduled based on the results and status of completed or in-progress ones.
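The N/C pattern can be sketched without any HPO library. In this toy scheduler (our own illustration; `toy_objective`, the perturb-the-best sampler, and `parallel_search` are all hypothetical), at most `c` trials are in flight, and each time one finishes, the next configuration is proposed from the results collected so far:

```python
import asyncio
import random


# Hypothetical toy objective: a short random sleep, then a score that peaks at x = 3.
async def toy_objective(x: float) -> float:
    await asyncio.sleep(random.uniform(0.001, 0.005))
    return -((x - 3.0) ** 2)


async def parallel_search(n: int, c: int, seed: int = 0) -> tuple[list[tuple[float, float]], int]:
    """Run n trials, keeping at most c in flight, proposing each new x from finished results."""
    rng = random.Random(seed)
    results: list[tuple[float, float]] = []  # (x, score) of every finished trial
    in_flight: dict[asyncio.Task, float] = {}  # running task -> its x
    launched = peak = 0

    def propose() -> float:
        # Toy sampler: sample uniformly until a result exists, then perturb the best x so far.
        if not results:
            return rng.uniform(-10.0, 10.0)
        best_x = max(results, key=lambda r: r[1])[0]
        return best_x + rng.gauss(0.0, 1.0)

    while launched < n or in_flight:
        # Top up to c concurrent trials while the budget of n trials allows.
        while launched < n and len(in_flight) < c:
            x = propose()
            in_flight[asyncio.create_task(toy_objective(x))] = x
            launched += 1
        peak = max(peak, len(in_flight))
        # Wake up as soon as any trial finishes, record it, then schedule more.
        done, _ = await asyncio.wait(set(in_flight), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            results.append((in_flight.pop(task), task.result()))
    return results, peak


results, peak = asyncio.run(parallel_search(n=20, c=4))
print(len(results), peak)  # 20 4
```

Here `results` acts as the centralized state: the sampler can only propose informed configurations because every finished trial is recorded there.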

However, parallel HPO introduces a new complexity: the need for a centralized state that tracks:

- All past trials (successes and failures)
- All ongoing trials

This state is essential so that the optimization algorithm can make informed decisions about which hyperparameters to try next.

## A better way to run HPO

This is where Flyte shines.

- There's no need to manage a separate centralized database for state tracking, as every objective run is **cached**, **recorded**, and **recoverable** via Flyte's execution engine.
- The entire HPO process is observable in the UI with full lineage and metadata for each trial.
- Each objective is seeded for reproducibility, enabling deterministic trial results.
- If the main optimization task crashes or is terminated, **Flyte can resume from the last successful or failed trial, making the experiment highly fault-tolerant**.
- Trial functions can be strongly typed, enabling rich, flexible hyperparameter spaces while maintaining strict type safety across trials.

In this example, we combine Flyte with Optuna to optimize a `RandomForestClassifier` on the Iris dataset. Each trial runs in an isolated task, and the optimization process is orchestrated asynchronously, with Flyte handling the underlying scheduling, retries, and caching.

## Declare dependencies

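We start by declaring the script's Python version (3.13) and runtime dependencies as inline script metadata, the `# /// script` block defined by PEP 723. Tools like `uv` read this block when running the script, and `flyte.Image.from_uv_script` uses it later to build the container image.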

```
# /// script
requires-python = "==3.13"
dependencies = [
   "optuna>=4.0.0,<5.0.0",
   "flyte>=2.0.0b0",
   "scikit-learn==1.7.0",
]
# ///
```

With the environment defined, we begin by importing standard library and third-party modules necessary for both the ML task and distributed execution.

```
import asyncio
import typing
from collections import Counter
from typing import Optional, Union
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

These standard library imports are essential for asynchronous execution (`asyncio`), type annotations (`typing`, `Optional`, `Union`), and aggregating trial state counts (`Counter`).

```
import optuna
from optuna import Trial
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.utils import shuffle
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

We use Optuna for hyperparameter optimization and several utilities from scikit-learn to prepare data (`load_iris`), define the model (`RandomForestClassifier`), evaluate it (`cross_val_score`), and shuffle the dataset for randomness (`shuffle`).

```
import flyte
import flyte.errors
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

Flyte is our orchestration framework. We use it to define tasks, manage resources, and recover from execution errors.

## Define the task environment

We define a Flyte task environment called `driver`, which encapsulates metadata, compute resources, the container image context needed for remote execution, and caching behavior.

```
driver = flyte.TaskEnvironment(
    name="driver",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="optimizer"),
    cache="auto",
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

This environment runs its tasks with 1 CPU and 250Mi of memory, builds the container image from this script's inline dependency metadata (`flyte.Image.from_uv_script(__file__, ...)`), and enables caching with `cache="auto"`.

<p>
  You can configure the Flyte task environment to reuse containers across multiple executions by setting the
  <code>reusable</code> field to
  <code>flyte.ReusePolicy(replicas=..., idle_ttl=...)</code>. This is especially useful when each objective
  computation is short-lived, as it avoids paying container startup costs for every trial. Learn more about reusable containers
  <a href="../../user-guide/reusable-containers/">here</a>.
</p>

## Define the optimizer

Next, we define an `Optimizer` class that handles parallel execution of Optuna trials using async coroutines. This class abstracts the full optimization loop and supports concurrent trial execution with live logging.

```
class Optimizer:
    def __init__(
        self,
        objective: typing.Callable,
        n_trials: int,
        concurrency: int = 1,
        delay: float = 0.1,
        study: Optional[optuna.Study] = None,
        log_delay: float = 0.1,
    ):
        self.n_trials: int = n_trials
        self.concurrency: int = concurrency
        self.objective: typing.Callable = objective
        self.delay: float = delay
        self.log_delay = log_delay

        self.study = study if study else optuna.create_study()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

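We pass the `objective` function, the number of trials to run (`n_trials`), and the maximum number of trials that may run in parallel (`concurrency`). The optional `delay` (in seconds) pauses after each trial before its concurrency slot is released, throttling how quickly new trials start, while `log_delay` sets the interval between progress log lines (a value of `0` disables logging). If no Optuna study is provided, a new one is created with Optuna's defaults, which minimize the objective.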

```
    async def log(self):
        while True:
            await asyncio.sleep(self.log_delay)

            counter = Counter()

            for trial in self.study.trials:
                counter[trial.state.name.lower()] += 1

            counts = dict(counter, queued=self.n_trials - len(self))

            # print items in dictionary in a readable format
            formatted = [f"{name}: {count}" for name, count in counts.items()]
            print(f"{'    '.join(formatted)}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

This method periodically prints how many trials are in each Optuna state (for example, `running`, `complete`, or `fail`), plus how many are still `queued`, that is, not yet requested from the study. It runs as a background task whenever `log_delay` is nonzero, so users can follow the optimization as it progresses.

![Optuna logging](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/hpo/logging.png)
_Logs are streamed live as the execution progresses._
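The counting in `log` relies on `collections.Counter` preserving first-seen order and on `dict(counter, queued=...)` appending the queued count last. A stand-alone sketch of the same formatting, using plain state-name strings in place of Optuna trial objects (`format_status` is a hypothetical helper):

```python
from collections import Counter


# Hypothetical helper mirroring the formatting in Optimizer.log.
def format_status(states: list[str], n_trials: int) -> str:
    """Count trials per state and append how many are still queued."""
    counter = Counter(state.lower() for state in states)
    # dict(...) keeps Counter's first-seen order and adds "queued" last.
    counts = dict(counter, queued=n_trials - len(states))
    return "    ".join(f"{name}: {count}" for name, count in counts.items())


print(format_status(["COMPLETE", "RUNNING", "COMPLETE", "FAIL"], n_trials=10))
# complete: 2    running: 1    fail: 1    queued: 6
```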

```
    async def spawn(self, semaphore: asyncio.Semaphore):
        async with semaphore:
            trial: Trial = self.study.ask()

            try:
                print("Starting trial", trial.number)

                params = {
                    "n_estimators": trial.suggest_int("n_estimators", 10, 200),
                    "max_depth": trial.suggest_int("max_depth", 2, 20),
                    "min_samples_split": trial.suggest_float(
                        "min_samples_split", 0.1, 1.0
                    ),
                }

                output = await self.objective(params)

                self.study.tell(trial, output, state=optuna.trial.TrialState.COMPLETE)
            except flyte.errors.RuntimeUserError as e:
                print(f"Trial {trial.number} failed: {e}")

                self.study.tell(trial, state=optuna.trial.TrialState.FAIL)

            await asyncio.sleep(self.delay)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

Each call to `spawn` runs a single trial through Optuna's ask-and-tell interface. The semaphore caps the number of simultaneously active trials at `concurrency`. We first ask the study for a new trial, build the parameter dictionary by querying the trial for suggested hyperparameters, and await the objective function. On success, we tell the study the score and mark the trial as `COMPLETE`. If the trial fails with a Flyte `RuntimeUserError`, we log the error and mark the trial as `FAIL`; any other exception is not caught and propagates to the caller.

```
    async def __call__(self):
        # create semaphore to manage concurrency
        semaphore = asyncio.Semaphore(self.concurrency)

        # create list of async trials
        trials = [self.spawn(semaphore) for _ in range(self.n_trials)]

        logger: Optional[asyncio.Task] = None
        if self.log_delay:
            logger = asyncio.create_task(self.log())

        # await all trials to complete
        await asyncio.gather(*trials)

        if self.log_delay and logger:
            logger.cancel()
            try:
                await logger
            except asyncio.CancelledError:
                pass
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

The `__call__` method runs the whole optimization. It creates the semaphore, creates `n_trials` `spawn` coroutines, and, if `log_delay` is set, starts the background logging task. It then awaits all trials with `asyncio.gather`; once they finish, it cancels the logger and suppresses the resulting `CancelledError`.
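The start, gather, cancel, and await sequence in `__call__` is a standard asyncio pattern for a background task that should live only as long as the main work. A stdlib-only sketch of it, with hypothetical `heartbeat` and `run_with_logger` coroutines standing in for `log` and `__call__`, using `contextlib.suppress` as an equivalent to the `try`/`except asyncio.CancelledError: pass` block:

```python
import asyncio
import contextlib


# Hypothetical stand-in for Optimizer.log: runs until cancelled.
async def heartbeat(interval: float, ticks: list[int]) -> None:
    while True:
        await asyncio.sleep(interval)
        ticks.append(len(ticks))


# Hypothetical stand-in for Optimizer.__call__.
async def run_with_logger() -> tuple[list[int], bool]:
    ticks: list[int] = []
    logger = asyncio.create_task(heartbeat(0.01, ticks))

    # The main work: three short jobs awaited together, like the trials.
    results = await asyncio.gather(*(asyncio.sleep(0.05, result=i) for i in range(3)))

    # Request cancellation, then wait for the task to actually stop.
    logger.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await logger
    return list(results), logger.cancelled()


results, cancelled = asyncio.run(run_with_logger())
print(results, cancelled)  # [0, 1, 2] True
```

Awaiting the cancelled task matters because `cancel()` only requests cancellation; the `await` ensures the task has actually stopped before the method returns.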

```
    def __len__(self) -> int:
        """Return the number of trials in history."""
        return len(self.study.trials)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

This method returns the number of trials already created in the study, including running ones; `log` uses it to compute how many trials are still queued.

## Define the objective function

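The `objective` task defines how we evaluate a particular set of hyperparameters. Because it is a Flyte task in the `driver` environment (which sets `cache="auto"`), every call is cached, tracked, and recoverable across executions; because it is `async`, the optimizer can await many trials concurrently.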

```
@driver.task
async def objective(params: dict[str, Union[int, float]]) -> float:
    data = load_iris()
    X, y = shuffle(data.data, data.target, random_state=42)

    clf = RandomForestClassifier(
        n_estimators=params["n_estimators"],
        max_depth=params["max_depth"],
        min_samples_split=params["min_samples_split"],
        random_state=42,
        n_jobs=-1,
    )

    # Use cross-validation to evaluate performance
    score = cross_val_score(clf, X, y, cv=3, scoring="accuracy").mean()

    return score.item()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

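We use the Iris dataset as a toy classification problem. The input `params` dictionary holds the trial's hyperparameters, which we pass to a `RandomForestClassifier`. We shuffle the dataset and seed the model with a fixed `random_state=42`, so a given set of hyperparameters always yields the same score, and return the mean 3-fold cross-validation accuracy as a plain Python `float` (via `.item()`).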
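In symbols, for a hyperparameter setting $\theta$, `cross_val_score` with `cv=3` fits three models (for a classifier, scikit-learn uses stratified folds), each trained on two folds and scored on the held-out third, and the task returns the mean accuracy:

$$
\text{score}(\theta) = \frac{1}{3}\sum_{k=1}^{3} \text{acc}_k(\theta),
\qquad
\text{acc}_k(\theta) = \frac{\text{correct predictions on fold } k}{\text{samples in fold } k}
$$

For example, with hypothetical fold accuracies of 0.96, 0.94, and 0.98, the returned score is (0.96 + 0.94 + 0.98) / 3 = 0.96.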

## Define the main optimization loop

The `optimize` task is the main driver of our optimization experiment. It creates the `Optimizer` instance and invokes it.

```
@driver.task
async def optimize(
    n_trials: int = 20,
    concurrency: int = 5,
    delay: float = 0.05,
    log_delay: float = 0.1,
) -> dict[str, Union[int, float]]:
    optimizer = Optimizer(
        objective=objective,
        n_trials=n_trials,
        concurrency=concurrency,
        delay=delay,
        log_delay=log_delay,
        study=optuna.create_study(
            direction="maximize", sampler=optuna.samplers.TPESampler(seed=42)
        ),
    )

    await optimizer()

    best = optimizer.study.best_trial

    print("✅ Best Trial")
    print("  Number :", best.number)
    print("  Params :", best.params)
    print("  Score  :", best.value)

    return best.params
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

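We create the study with `direction="maximize"`, since the objective returns accuracy and Optuna minimizes by default, and configure a `TPESampler` with a fixed `seed` so that, for the same sequence of completed trials, its suggestions are reproducible. With concurrent trials, completion order can vary between runs, so the suggested configurations may still differ. After running all trials, we extract the best-performing trial and print its number, parameters, and score. Returning the best parameters lets downstream tasks or clients retrain or deploy a model with the tuned hyperparameters.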

## Run the experiment

Finally, we include an executable entry point to run this optimization using `flyte.run`.

```
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(optimize, n_trials=100, concurrency=10)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py*

We load the Flyte configuration from `config.yaml`, launch the `optimize` task with 100 trials and a concurrency of 10, print a link to the execution in the Flyte UI, and then block until the run finishes with `run.wait()`.

![HPO execution](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/tutorials/hpo/execution.png)
_Each objective run is cached, recorded, and recoverable. With concurrency set to 10, at most 10 trials execute in parallel at any given time._

