# Integrations
> This bundle contains all pages in the Integrations section.
> Source: https://www.union.ai/docs/v2/union/integrations/

=== PAGE: https://www.union.ai/docs/v2/union/integrations ===

# Integrations

Flyte 2 is designed to be extensible by default. While the core platform covers the most common orchestration needs, many production workloads require specialized infrastructure, external services or execution semantics that go beyond the core runtime.

Flyte 2 exposes these capabilities through integrations.

Under the hood, integrations are implemented using Flyte 2's plugin system, which provides a consistent way to extend the platform without modifying core execution logic.

An integration allows you to declaratively enable new capabilities such as distributed compute frameworks or third-party services without manually managing infrastructure. You specify what you need, and Flyte takes care of how it is provisioned, used and cleaned up.

This page covers:

- The types of integrations Flyte 2 supports today
- How integrations fit into Flyte 2's execution model
- How to use integrations in your tasks
- The integrations available out of the box

If you need functionality that doesn't exist yet, Flyte 2's plugin system is intentionally open-ended. You can build and register your own integrations using the same architecture described here.

## Integration categories

Flyte 2 integrations fall into the following categories:

1. **Distributed compute**: Provision transient compute clusters to run tasks across multiple nodes, with automatic lifecycle management.
2. **Agentic AI**: Drop-in replacements for LLM provider SDKs that let Flyte tasks serve as agent tools.
3. **Experiment tracking**: Integrate with experiment tracking platforms for logging metrics, parameters, and artifacts.
4. **Connectors**: Stateless, long-running services that receive execution requests via gRPC and then submit work to external (or internal) systems.
5. **LLM Serving**: Deploy and serve large language models with an OpenAI-compatible API.

## Distributed compute

Distributed compute integrations allow tasks to run on dynamically provisioned clusters. These clusters are created just-in-time, scoped to the task execution and torn down automatically when the task completes.

This enables large-scale parallelism without requiring users to operate or maintain long-running infrastructure.

### Supported distributed compute integrations

| Plugin               | Description                                      | Common use cases                                       |
| -------------------- | ------------------------------------------------ | ------------------------------------------------------ |
| [Ray](./ray/_index)         | Provisions Ray clusters via KubeRay              | Distributed Python, ML training, hyperparameter tuning |
| [Spark](./spark/_index)     | Provisions Spark clusters via Spark Operator     | Large-scale data processing, ETL pipelines             |
| [Dask](./dask/_index)       | Provisions Dask clusters via Dask Operator       | Parallel Python workloads, dataframe operations        |
| [PyTorch](./pytorch/_index) | Distributed PyTorch training with elastic launch | Single-node and multi-node training                    |

Each plugin encapsulates:

- Cluster provisioning
- Resource configuration
- Networking and service discovery
- Lifecycle management and teardown

From the task author's perspective, these details are abstracted away.

### How the plugin system works

At a high level, Flyte 2's distributed compute plugin architecture follows a simple and consistent pattern.

#### 1. Registration

Each plugin registers itself with Flyte 2's core plugin registry:

- **`TaskPluginRegistry`**: The central registry for all distributed compute plugins
- Each plugin declares:
  - Its configuration schema
  - How that configuration maps to execution behavior

This registration step makes the plugin discoverable by the runtime.

#### 2. Task environments and plugin configuration

Integrations are activated through a `TaskEnvironment`.

A `TaskEnvironment` bundles:

- A container image
- Execution settings
- A plugin configuration object, passed via the `plugin_config` parameter

The plugin configuration describes _what_ infrastructure or integration the task requires.

#### 3. Automatic provisioning and execution

When a task associated with a `TaskEnvironment` runs:

1. Flyte inspects the environment's plugin configuration
2. The plugin provisions the required infrastructure or integration
3. The task executes with access to that capability
4. Flyte cleans up all transient resources after completion

### Example: Using the Dask plugin

Below is a complete example showing how a task gains access to a Dask cluster simply by running inside an environment configured with the Dask plugin.

```python
from flyteplugins.dask import Dask, WorkerGroup
import flyte

# The image must include the Dask plugin (which brings in `distributed`)
image = flyte.Image.from_debian_base(name="dask").with_pip_packages("flyteplugins-dask")

# Define the Dask cluster configuration
dask_config = Dask(
    workers=WorkerGroup(number_of_workers=4)
)

# Create a task environment that enables Dask
env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
)

# A plain function that each worker applies to one item
def transform(item: int) -> int:
    return item * 2

# Any task in this environment has access to the Dask cluster
@env.task
async def process_data(data: list) -> list:
    from distributed import Client

    client = Client()  # Automatically connects to the provisioned cluster
    futures = client.map(transform, data)
    return client.gather(futures)
```

When `process_data` executes, Flyte performs the following steps:

1. Provisions a Dask cluster with 4 workers
2. Executes the task with network access to the cluster
3. Tears down the cluster once the task completes

No cluster management logic appears in the task code. The task only expresses intent.

### Key design principle

All distributed compute integrations follow the same mental model:

- You declare the required capability via configuration
- You attach that configuration to a task environment
- Tasks decorated with that environment automatically gain access to the capability

This makes it easy to swap execution backends or introduce distributed compute incrementally without rewriting workflows.

## Agentic AI

Agentic AI integrations provide drop-in replacements for LLM provider SDKs. They let you use Flyte tasks as agent tools so that tool calls run with full Flyte observability, retries, and caching.

### Supported agentic AI integrations

| Plugin                              | Description                                                  | Common use cases                     |
| ----------------------------------- | ------------------------------------------------------------ | ------------------------------------ |
| [OpenAI](./openai/_index)           | Drop-in replacement for OpenAI Agents SDK `function_tool`    | Agentic workflows with OpenAI models |
| [Anthropic](./anthropic/_index)     | Agent loop and `function_tool` for the Anthropic Claude SDK  | Agentic workflows with Claude        |
| [Gemini](./gemini/_index)           | Agent loop and `function_tool` for the Google Gemini SDK     | Agentic workflows with Gemini        |
| [Code generation](./codegen/_index) | LLM-driven code generation with automatic testing in sandboxes | Data processing, ETL, analysis pipelines |

## Experiment tracking

Experiment tracking integrations let you log metrics, parameters, and artifacts to external tracking platforms during Flyte task execution.

### Supported experiment tracking integrations

| Plugin                               | Description                  | Common use cases                              |
| ------------------------------------ | ---------------------------- | --------------------------------------------- |
| [MLflow](./mlflow/_index)            | MLflow experiment tracking   | Experiment tracking, autologging, model registry |
| [Weights and Biases](./wandb/_index) | Weights & Biases integration | Experiment tracking and hyperparameter tuning |

## Connectors

Connectors are stateless, long-running services that receive execution requests via gRPC and then submit work to external (or internal) systems. Each connector runs as its own Kubernetes deployment and is triggered when a Flyte task of the matching type is executed.

Although connectors normally run inside the control plane, you can also run them locally, as long as the required secrets and credentials are available on your machine. This works because connectors are just Python services that can be spawned in-process.

Connectors are designed to scale horizontally and reduce load on the core Flyte backend because they execute _outside_ the core system. This decoupling makes connectors efficient, resilient, and easy to iterate on. You can even test them locally without modifying backend configuration, which reduces friction during development.

### Supported connectors

| Connector                          | Description                                    | Common use cases                         |
| ---------------------------------- | ---------------------------------------------- | ---------------------------------------- |
| [Snowflake](./snowflake/_index)    | Run SQL queries on Snowflake asynchronously    | Data warehousing, ETL, analytics queries |
| [BigQuery](./bigquery/_index)      | Run SQL queries on Google BigQuery             | Data warehousing, ETL, analytics queries |
| [Databricks](./databricks/_index)  | Run PySpark jobs on Databricks clusters        | Large-scale data processing, Spark ETL   |

### Creating a new connector

If none of the existing connectors meet your needs, you can build your own.

> [!NOTE]
> Connectors communicate via Protobuf, so in theory they can be implemented in any language.
> Today, only **Python** connectors are supported.

### Async connector interface

To implement a new async connector, extend `AsyncConnector` and implement the following methods, all of which must be idempotent:

| Method   | Purpose                                                     |
| -------- | ----------------------------------------------------------- |
| `create` | Launch the external job (via REST, gRPC, SDK, or other API) |
| `get`    | Fetch current job state (return job status or output)       |
| `delete` | Delete / cancel the external job                            |

To test the connector locally, the connector task should inherit from
[AsyncConnectorExecutorMixin](https://github.com/flyteorg/flyte-sdk/blob/1d49299294cd5e15385fe8c48089b3454b7a4cd1/src/flyte/connectors/_connector.py#L206). This mixin simulates how the Flyte 2 system executes asynchronous connector tasks, making it easier to validate your connector implementation before deploying it.

### Example: Model training connector

The following example implements a connector that launches a model training job on an external training service.

```python
import typing
from dataclasses import dataclass

import httpx
from flyte.connectors import AsyncConnector, Resource, ResourceMeta
from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
from flyteidl2.core.tasks_pb2 import TaskTemplate
from google.protobuf import json_format

@dataclass
class ModelTrainJobMeta(ResourceMeta):
    job_id: str
    endpoint: str

class ModelTrainingConnector(AsyncConnector):
    """
    Example connector that launches a ML model training job on an external training service.

    POST → launch training job
    GET  → poll training progress
    DELETE → cancel training job
    """

    name = "Model Training Connector"
    task_type_name = "external_model_training"
    metadata_type = ModelTrainJobMeta

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: typing.Optional[typing.Dict[str, typing.Any]],
        **kwargs,
    ) -> ModelTrainJobMeta:
        """
        Submit training job via POST.
        Response returns job_id we later use in get().
        """
        custom = json_format.MessageToDict(task_template.custom) if task_template.custom else None
        async with httpx.AsyncClient() as client:
            r = await client.post(
                custom["endpoint"],
                json={"dataset_uri": inputs["dataset_uri"], "epochs": inputs["epochs"]},
            )
        r.raise_for_status()
        return ModelTrainJobMeta(job_id=r.json()["job_id"], endpoint=custom["endpoint"])

    async def get(self, resource_meta: ModelTrainJobMeta, **kwargs) -> Resource:
        """
        Poll external API until training job finishes.
        Must be safe to call repeatedly.
        """
        async with httpx.AsyncClient() as client:
            r = await client.get(f"{resource_meta.endpoint}/{resource_meta.job_id}")

        data = r.json()

        if data["status"] == "finished":
            return Resource(
                phase=TaskExecution.SUCCEEDED,
                log_links=[TaskLog(name="training-dashboard", uri=f"https://example-mltrain.com/train/{resource_meta.job_id}")],
                outputs={"results": data["results"]},
            )

        return Resource(phase=TaskExecution.RUNNING)

    async def delete(self, resource_meta: ModelTrainJobMeta, **kwargs):
        """
        Optionally call DELETE on external API.
        Safe even if job already completed.
        """
        async with httpx.AsyncClient() as client:
            await client.delete(f"{resource_meta.endpoint}/{resource_meta.job_id}")
```

To use this connector, define a task whose `task_type` matches the connector's `task_type_name`.

```python
import flyte.io
from typing import Any, Dict, Optional

from flyte.extend import TaskTemplate
from flyte.connectors import AsyncConnectorExecutorMixin
from flyte.models import NativeInterface, SerializationContext

class ModelTrainTask(AsyncConnectorExecutorMixin, TaskTemplate):
    _TASK_TYPE = "external_model_training"

    def __init__(
        self,
        name: str,
        endpoint: str,
        **kwargs,
    ):
        super().__init__(
            name=name,
            interface=NativeInterface(
                inputs={"epochs": int, "dataset_uri": str},
                outputs={"results": flyte.io.File},
            ),
            task_type=self._TASK_TYPE,
            **kwargs,
        )
        self.endpoint = endpoint

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        return {"endpoint": self.endpoint}
```

Here is an example of how to use the `ModelTrainTask`:

```python
import flyte
from flyteplugins.model_training import ModelTrainTask

model_train_task = ModelTrainTask(
    name="model_train",
    endpoint="https://example-mltrain.com",
)

model_train_env = flyte.TaskEnvironment.from_task("model_train_env", model_train_task)

env = flyte.TaskEnvironment(
    name="hello_world",
    resources=flyte.Resources(memory="250Mi"),
    image=flyte.Image.from_debian_base(name="model_training").with_pip_packages(
        "flyteplugins-model-training", pre=True
    ),
    depends_on=[model_train_env],
)

@env.task
def data_prep() -> str:
    return "gs://my-bucket/dataset.csv"

@env.task
def train_model(epochs: int) -> flyte.io.File:
    dataset_uri = data_prep()
    return model_train_task(epochs=epochs, dataset_uri=dataset_uri)
```

### Build a custom connector image

Build a custom image when you're ready to deploy your connector to your cluster.
To build the Docker image for your connector, run the following script:

```python
import asyncio
from flyte import Image
from flyte.extend import ImageBuildEngine

async def build_flyte_connector_bigquery_image(registry: str, name: str, builder: str = "local"):
    """
    Build the SDK default connector image optionally overriding
    the container registry and image name.

    Args:
        registry: e.g. "ghcr.io/my-org" or "123456789012.dkr.ecr.us-west-2.amazonaws.com".
        name:     e.g. "my-connector".
        builder:  e.g. "local" or "remote".
    """

    default_image = Image.from_debian_base(
        registry=registry, name=name
    ).with_pip_packages("flyteplugins-bigquery", pre=True)
    await ImageBuildEngine.build(default_image, builder=builder)

if __name__ == "__main__":
    print("Building connector image...")
    asyncio.run(
        build_flyte_connector_bigquery_image(
            registry="<YOUR_REGISTRY>", name="flyte-bigquery", builder="local"
        )
    )
```

## LLM Serving

LLM serving integrations let you deploy and serve large language models as Flyte apps with an OpenAI-compatible API. They handle model loading, GPU management, and autoscaling.

### Supported LLM serving integrations

| Plugin                                                            | Description                                           | Common use cases                     |
| ----------------------------------------------------------------- | ----------------------------------------------------- | ------------------------------------ |
| [SGLang](https://www.union.ai/docs/v2/union/user-guide/build-apps/sglang-app/page.md)  | Deploy models with SGLang's high-throughput runtime   | LLM inference, model serving         |
| [vLLM](https://www.union.ai/docs/v2/union/user-guide/build-apps/vllm-app/page.md)      | Deploy models with vLLM's PagedAttention engine       | LLM inference, model serving         |

For full setup instructions including multi-GPU deployment, model prefetching, and autoscaling, see the [SGLang app](https://www.union.ai/docs/v2/union/user-guide/build-apps/sglang-app/page.md) and [vLLM app](https://www.union.ai/docs/v2/union/user-guide/build-apps/vllm-app/page.md) pages.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/anthropic ===

# Anthropic

The Anthropic plugin lets you build agentic workflows with [Claude](https://www.anthropic.com/) on Flyte. It provides a `function_tool` decorator that wraps Flyte tasks as tools that Claude can call, and a `run_agent` function that drives the agent conversation loop.

When Claude calls a tool, the call executes as a Flyte task with full observability, retries, and caching.

## Installation

```bash
pip install flyteplugins-anthropic
```

Requires `anthropic >= 0.40.0`.

## Quick start

```python
import flyte
from flyteplugins.anthropic import function_tool, run_agent

env = flyte.TaskEnvironment(
    name="claude-agent",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="anthropic_agent"),
    secrets=flyte.Secret("anthropic_api_key", as_env_var="ANTHROPIC_API_KEY"),
)

@function_tool
@env.task
async def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"The weather in {city} is sunny, 72F"

@env.task
async def main(prompt: str) -> str:
    tools = [get_weather]
    return await run_agent(prompt=prompt, tools=tools)
```

## API

### `function_tool`

Converts a Flyte task, `@flyte.trace`-decorated function, or plain callable into a tool that Claude can invoke.

```python
@function_tool
@env.task
async def my_tool(param: str) -> str:
    """Tool description sent to Claude."""
    ...
```

Can also be called with optional overrides:

```python
@function_tool(name="custom_name", description="Custom description")
@env.task
async def my_tool(param: str) -> str:
    ...
```

Parameters:

| Parameter | Type | Description |
|-----------|------|-------------|
| `func` | callable | The function to wrap |
| `name` | `str` | Override the tool name (defaults to the function name) |
| `description` | `str` | Override the tool description (defaults to the docstring) |

> [!NOTE]
> The docstring on each `@function_tool` task is sent to Claude as the tool description. Write clear, concise docstrings.

### `Agent`

A dataclass for bundling agent configuration:

```python
from flyteplugins.anthropic import Agent

agent = Agent(
    name="my-agent",
    instructions="You are a helpful assistant.",
    model="claude-sonnet-4-20250514",
    tools=[get_weather],
    max_tokens=4096,
    max_iterations=10,
)
```

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `name` | `str` | `"assistant"` | Agent name |
| `instructions` | `str` | `"You are a helpful assistant."` | System prompt |
| `model` | `str` | `"claude-sonnet-4-20250514"` | Claude model ID |
| `tools` | `list[FunctionTool]` | `[]` | Tools available to the agent |
| `max_tokens` | `int` | `4096` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum tool-call loop iterations |

### `run_agent`

Runs a Claude conversation loop, dispatching tool calls to Flyte tasks until Claude returns a final response.

```python
result = await run_agent(
    prompt="What's the weather in Tokyo?",
    tools=[get_weather],
    model="claude-sonnet-4-20250514",
)
```

You can also pass an `Agent` object:

```python
result = await run_agent(prompt="What's the weather?", agent=agent)
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `prompt` | `str` | required | User message |
| `tools` | `list[FunctionTool]` | `None` | Tools available to the agent |
| `agent` | `Agent` | `None` | Agent config (overrides individual params) |
| `model` | `str` | `"claude-sonnet-4-20250514"` | Claude model ID |
| `system` | `str` | `None` | System prompt |
| `max_tokens` | `int` | `4096` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum iterations (prevents infinite loops) |
| `api_key` | `str` | `None` | API key (falls back to `ANTHROPIC_API_KEY` env var) |

## Secrets

Store your Anthropic API key as a Flyte secret and expose it as an environment variable:

```python
secrets=flyte.Secret("anthropic_api_key", as_env_var="ANTHROPIC_API_KEY")
```

## API reference

See the [Anthropic API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/anthropic/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/bigquery ===

# BigQuery

The BigQuery connector lets you run SQL queries against [Google BigQuery](https://cloud.google.com/bigquery) directly from Flyte tasks. Queries are submitted asynchronously via the BigQuery Jobs API and polled for completion, so they don't block a worker while waiting for results.

The connector supports:

- Parameterized SQL queries with typed inputs
- Google Cloud service account authentication
- Query results returned as DataFrames
- Query cancellation on task abort

## Installation

```bash
pip install flyteplugins-bigquery
```

This installs the Google Cloud BigQuery client libraries.

## Quick start

Here's a minimal example that runs a SQL query on BigQuery:

```python
from flyte.io import DataFrame
from flyteplugins.bigquery import BigQueryConfig, BigQueryTask

config = BigQueryConfig(
    ProjectID="my-gcp-project",
    Location="US",
)

count_users = BigQueryTask(
    name="count_users",
    query_template="SELECT COUNT(*) FROM dataset.users",
    plugin_config=config,
    output_dataframe_type=DataFrame,
)
```

This defines a task called `count_users` that runs the query on the configured BigQuery instance. When executed, the connector:

1. Connects to BigQuery using the provided configuration
2. Submits the query asynchronously via the Jobs API
3. Polls until the query completes or fails

To run the task, create a `TaskEnvironment` from it and execute it locally or remotely:

```python
import flyte

bigquery_env = flyte.TaskEnvironment.from_task("bigquery_env", count_users)

if __name__ == "__main__":
    flyte.init_from_config()

    # Run locally (connector runs in-process, requires credentials locally)
    run = flyte.with_runcontext(mode="local").run(count_users)

    # Run remotely (connector runs on the control plane)
    run = flyte.with_runcontext(mode="remote").run(count_users)

    print(run.url)
```

> [!NOTE]
> The `TaskEnvironment` created by `from_task` does not need an image or pip packages. BigQuery tasks are connector tasks, which means the query executes on the connector service, not in your task container. In `local` mode, the connector runs in-process and requires `flyteplugins-bigquery` and credentials to be available on your machine.

## Configuration

### `BigQueryConfig` parameters

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `ProjectID` | `str` | Yes | GCP project ID |
| `Location` | `str` | No | BigQuery region (e.g., `"US"`, `"EU"`) |
| `QueryJobConfig` | `bigquery.QueryJobConfig` | No | Native BigQuery [QueryJobConfig](https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.QueryJobConfig) object for advanced settings |

### `BigQueryTask` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `name` | `str` | Unique task name |
| `query_template` | `str` | SQL query (whitespace is normalized before execution) |
| `plugin_config` | `BigQueryConfig` | Connection configuration |
| `inputs` | `Dict[str, Type]` | Named typed inputs bound as query parameters |
| `output_dataframe_type` | `Type[DataFrame]` | If set, query results are returned as a `DataFrame` |
| `google_application_credentials` | `str` | Name of the Flyte secret containing the GCP service account JSON key |

## Authentication

Pass the name of a Flyte secret containing your GCP service account JSON key:

```python
query = BigQueryTask(
    name="secure_query",
    query_template="SELECT * FROM dataset.sensitive_data",
    plugin_config=config,
    google_application_credentials="my-gcp-sa-key",
)
```

## Query templating

Use the `inputs` parameter to define typed inputs for your query. Input values are bound as BigQuery `ScalarQueryParameter` values.

### Supported input types

| Python type | BigQuery type |
|-------------|---------------|
| `int` | `INT64` |
| `float` | `FLOAT64` |
| `str` | `STRING` |
| `bool` | `BOOL` |
| `bytes` | `BYTES` |
| `datetime` | `DATETIME` |
| `list` | `ARRAY` |

### Parameterized query example

```python
from flyte.io import DataFrame

events_by_region = BigQueryTask(
    name="events_by_region",
    query_template="SELECT * FROM dataset.events WHERE region = @region AND score > @min_score",
    plugin_config=config,
    inputs={"region": str, "min_score": float},
    output_dataframe_type=DataFrame,
)
```

> [!NOTE]
> The query template is normalized before execution: newlines and tabs are replaced with spaces and consecutive whitespace is collapsed. You can format your queries across multiple lines for readability without affecting execution.

## Retrieving query results

Set `output_dataframe_type` to capture results as a DataFrame:

```python
from flyte.io import DataFrame

top_customers = BigQueryTask(
    name="top_customers",
    query_template="""
        SELECT customer_id, SUM(amount) AS total_spend
        FROM dataset.orders
        GROUP BY customer_id
        ORDER BY total_spend DESC
        LIMIT 100
    """,
    plugin_config=config,
    output_dataframe_type=DataFrame,
)
```

If you don't need query results (for example, DDL statements or INSERT queries), omit `output_dataframe_type`.

## API reference

See the [BigQuery API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/bigquery/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/dask ===

# Dask

The Dask plugin lets you run [Dask](https://www.dask.org/) jobs natively on Kubernetes. Flyte provisions a transient Dask cluster for each task execution using the [Dask Kubernetes Operator](https://kubernetes.dask.org/en/latest/operator.html) and tears it down on completion.

## When to use this plugin

- Parallel Python workloads that outgrow a single machine
- Distributed DataFrame operations on large datasets
- Workloads that use Dask's task scheduler for arbitrary computation graphs
- Jobs that need to scale NumPy, pandas, or scikit-learn workflows across multiple nodes

## Installation

```bash
pip install flyteplugins-dask
```

Your task image must also include the plugin, which brings in Dask's `distributed` package:

```python
image = flyte.Image.from_debian_base(name="dask").with_pip_packages("flyteplugins-dask")
```

## Configuration

Create a `Dask` configuration and pass it as `plugin_config` to a `TaskEnvironment`:

```python
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
)
```

### `Dask` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `scheduler` | `Scheduler` | Scheduler pod configuration (defaults to `Scheduler()`) |
| `workers` | `WorkerGroup` | Worker group configuration (defaults to `WorkerGroup()`) |

### `Scheduler` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `image` | `str` | Custom scheduler image (must include `dask[distributed]`) |
| `resources` | `Resources` | Resource requests for the scheduler pod |

### `WorkerGroup` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `number_of_workers` | `int` | Number of worker pods (default: `1`) |
| `image` | `str` | Custom worker image (must include `dask[distributed]`) |
| `resources` | `Resources` | Resource requests per worker pod |

> [!NOTE]
> The scheduler and all workers should use the same Python environment to avoid serialization issues.

### Accessing the Dask client

Inside a Dask task, create a `distributed.Client()` with no arguments. It automatically connects to the provisioned cluster:

```python
from distributed import Client

@dask_env.task
async def my_dask_task(n: int) -> list:
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))
    return client.gather(futures)
```

## Example

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-dask",
#    "distributed"
# ]
# main = "hello_dask_nested"
# params = ""
# ///

import asyncio
import typing

from distributed import Client
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

import flyte.remote
import flyte.storage
from flyte import Resources

image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("flyteplugins-dask")

dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

task_env = flyte.TaskEnvironment(
    name="hello_dask", resources=Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
    resources=Resources(cpu="1", memory="1Gi"),
    depends_on=[task_env],
)

@task_env.task()
async def hello_dask():
    await asyncio.sleep(5)
    print("Hello from the Dask task!")

@dask_env.task
async def hello_dask_nested(n: int = 3) -> typing.List[int]:
    print("running dask task")
    t = asyncio.create_task(hello_dask())
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))
    res = client.gather(futures)
    await t
    return res

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_dask_nested)
    print(r.name)
    print(r.url)
    r.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/dask/dask_example.py*

## API reference

See the [Dask API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/dask/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/databricks ===

# Databricks

The Databricks plugin lets you run PySpark jobs on [Databricks](https://www.databricks.com/) clusters directly from Flyte tasks. You write normal PySpark code in a Flyte task, and the plugin submits it to Databricks via the [Jobs API 2.1](https://docs.databricks.com/api/workspace/jobs/submit). The connector handles job submission, polling, and cancellation.

The plugin supports:

- Running PySpark tasks on new or existing Databricks clusters
- Full Spark configuration (driver/executor memory, cores, instances)
- Databricks cluster auto-scaling
- API token-based authentication

## Installation

```bash
pip install flyteplugins-databricks
```

This also installs `flyteplugins-spark` as a dependency, since the Databricks plugin extends the Spark plugin.

## Quick start

Create a `Databricks` configuration and pass it as `plugin_config` to a `TaskEnvironment`:

```python
from flyteplugins.databricks import Databricks
import flyte

image = (
    flyte.Image.from_base("databricksruntime/standard:16.4-LTS")
    .clone(name="spark", registry="ghcr.io/flyteorg", extendable=True)
    .with_env_vars({"UV_PYTHON": "/databricks/python3/bin/python"})
    .with_pip_packages("flyteplugins-databricks", pre=True)
)

databricks_conf = Databricks(
    spark_conf={
        "spark.driver.memory": "2000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
    executor_path="/databricks/python3/bin/python",
    databricks_conf={
        "run_name": "flyte databricks plugin",
        "new_cluster": {
            "spark_version": "13.3.x-scala2.12",
            "node_type_id": "m6i.large",
            "autoscale": {"min_workers": 1, "max_workers": 2},
        },
        "timeout_seconds": 3600,
        "max_retries": 1,
    },
    databricks_instance="myaccount.cloud.databricks.com",
    databricks_token="DATABRICKS_TOKEN",
)

databricks_env = flyte.TaskEnvironment(
    name="databricks_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("3000Mi", "5000Mi")),
    plugin_config=databricks_conf,
    image=image,
)
```

Then use the environment to decorate your task:

```python
@databricks_env.task
async def hello_databricks() -> float:
    spark = flyte.ctx().data["spark_session"]
    # Use spark as a normal SparkSession
    count = spark.sparkContext.parallelize(range(100)).count()
    return float(count)
```

## Configuration

The `Databricks` config extends the [Spark](../spark/_index) config with Databricks-specific fields.

### Spark fields (inherited)

| Parameter | Type | Description |
|-----------|------|-------------|
| `spark_conf` | `Dict[str, str]` | Spark configuration key-value pairs |
| `hadoop_conf` | `Dict[str, str]` | Hadoop configuration key-value pairs |
| `executor_path` | `str` | Path to the Python binary on the Databricks cluster (e.g., `/databricks/python3/bin/python`) |
| `applications_path` | `str` | Path to the main application file |

### Databricks-specific fields

| Parameter | Type | Description |
|-----------|------|-------------|
| `databricks_conf` | `Dict[str, Union[str, dict]]` | Databricks [run-submit](https://docs.databricks.com/api/workspace/jobs/submit) job configuration. Must contain either `existing_cluster_id` or `new_cluster` |
| `databricks_instance` | `str` | Your workspace domain (e.g., `myaccount.cloud.databricks.com`). Can also be set via the `FLYTE_DATABRICKS_INSTANCE` env var on the connector |
| `databricks_token` | `str` | Name of the Flyte secret containing the Databricks API token |

### `databricks_conf` structure

The `databricks_conf` dict maps to the Databricks run-submit API payload. Key fields:

| Field | Description |
|-------|-------------|
| `new_cluster` | Cluster spec with `spark_version`, `node_type_id`, `autoscale`, etc. |
| `existing_cluster_id` | ID of an existing cluster to use instead of creating a new one |
| `run_name` | Display name in the Databricks UI |
| `timeout_seconds` | Maximum job duration |
| `max_retries` | Number of retries before marking the job as failed |

The connector automatically injects the Docker image, Spark configuration, and environment variables from the task container into the cluster spec.
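
For example, to reuse a long-running cluster you would pass `existing_cluster_id` in place of `new_cluster`. A minimal sketch of such a payload, using only the fields documented above (the cluster ID is a placeholder):

```python
# Hypothetical run-submit payload that reuses an existing Databricks cluster
# instead of provisioning a new one. The cluster ID below is a placeholder.
databricks_conf = {
    "run_name": "flyte databricks plugin",
    "existing_cluster_id": "1234-567890-abcde123",
    "timeout_seconds": 3600,
    "max_retries": 1,
}
```

Note that `new_cluster` and `existing_cluster_id` are mutually exclusive: supply exactly one of them.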

## Authentication

Store your Databricks API token as a Flyte secret. The `databricks_token` parameter specifies the secret name:

```python
databricks_conf = Databricks(
    # ...
    databricks_token="DATABRICKS_TOKEN",
)
```

## Accessing the Spark session

Inside a Databricks task, the `SparkSession` is available through the task context, just like the [Spark plugin](../spark/_index):

```python
@databricks_env.task
async def my_databricks_task() -> float:
    spark = flyte.ctx().data["spark_session"]
    df = spark.read.parquet("s3://my-bucket/data.parquet")
    return float(df.count())
```

## API reference

See the [Databricks API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/databricks/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/gemini ===

# Gemini

The Gemini plugin lets you build agentic workflows with [Gemini](https://ai.google.dev/) on Flyte. It provides a `function_tool` decorator that wraps Flyte tasks as tools that Gemini can call, and a `run_agent` function that drives the agent conversation loop.

When Gemini calls a tool, the call executes as a Flyte task with full observability, retries, and caching. Gemini's native parallel function calling is supported: multiple tool calls in a single turn are all dispatched and their results bundled into one response.

## Installation

```bash
pip install flyteplugins-gemini
```

Requires `google-genai >= 1.0.0`.

## Quick start

```python
import flyte
from flyteplugins.gemini import function_tool, run_agent

env = flyte.TaskEnvironment(
    name="gemini-agent",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="gemini_agent"),
    secrets=flyte.Secret("google_api_key", as_env_var="GOOGLE_API_KEY"),
)

@function_tool
@env.task
async def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"The weather in {city} is sunny, 72F"

@env.task
async def main(prompt: str) -> str:
    tools = [get_weather]
    return await run_agent(prompt=prompt, tools=tools)
```

## API

### `function_tool`

Converts a Flyte task, `@flyte.trace`-decorated function, or plain callable into a tool that Gemini can invoke.

```python
@function_tool
@env.task
async def my_tool(param: str) -> str:
    """Tool description sent to Gemini."""
    ...
```

Can also be called with optional overrides:

```python
@function_tool(name="custom_name", description="Custom description")
@env.task
async def my_tool(param: str) -> str:
    ...
```

Parameters:

| Parameter | Type | Description |
|-----------|------|-------------|
| `func` | callable | The function to wrap |
| `name` | `str` | Override the tool name (defaults to the function name) |
| `description` | `str` | Override the tool description (defaults to the docstring) |

> [!NOTE]
> The docstring on each `@function_tool` task is sent to Gemini as the tool description. Write clear, concise docstrings.

### `Agent`

A dataclass for bundling agent configuration:

```python
from flyteplugins.gemini import Agent

agent = Agent(
    name="my-agent",
    instructions="You are a helpful assistant.",
    model="gemini-2.5-flash",
    tools=[get_weather],
    max_output_tokens=8192,
    max_iterations=10,
)
```

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `name` | `str` | `"assistant"` | Agent name |
| `instructions` | `str` | `"You are a helpful assistant."` | System prompt |
| `model` | `str` | `"gemini-2.5-flash"` | Gemini model ID |
| `tools` | `list[FunctionTool]` | `[]` | Tools available to the agent |
| `max_output_tokens` | `int` | `8192` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum tool-call loop iterations |

### `run_agent`

Runs a Gemini conversation loop, dispatching tool calls to Flyte tasks until Gemini returns a final response.

```python
result = await run_agent(
    prompt="What's the weather in Tokyo?",
    tools=[get_weather],
    model="gemini-2.5-flash",
)
```

You can also pass an `Agent` object:

```python
result = await run_agent(prompt="What's the weather?", agent=agent)
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `prompt` | `str` | required | User message |
| `tools` | `list[FunctionTool]` | `None` | Tools available to the agent |
| `agent` | `Agent` | `None` | Agent config (overrides individual params) |
| `model` | `str` | `"gemini-2.5-flash"` | Gemini model ID |
| `system` | `str` | `None` | System prompt |
| `max_output_tokens` | `int` | `8192` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum iterations (prevents infinite loops) |
| `api_key` | `str` | `None` | API key (falls back to `GOOGLE_API_KEY` env var) |
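
Conceptually, `run_agent` is a bounded tool-call loop: it calls the model, dispatches any requested tool calls, feeds the results back, and stops when the model returns a final answer or `max_iterations` is hit. A simplified, library-free sketch of that control flow (not the plugin's actual implementation; `call_model` and `dispatch_tool` are stand-ins for the Gemini API call and the Flyte task dispatch):

```python
async def agent_loop(prompt, call_model, dispatch_tool, max_iterations=10):
    """Bounded tool-call loop: ask the model, run any requested tools,
    feed the results back, and stop at a final answer or the iteration cap."""
    messages = [{"role": "user", "content": prompt}]
    for _ in range(max_iterations):
        reply = await call_model(messages)
        tool_calls = reply.get("tool_calls", [])
        if not tool_calls:
            return reply["content"]  # final text response, loop ends
        messages.append(reply)
        for call in tool_calls:
            # In the plugin, each dispatch executes as a tracked Flyte task.
            result = await dispatch_tool(call)
            messages.append({"role": "tool", "content": result})
    raise RuntimeError("max_iterations reached without a final response")
```

The `max_iterations` cap is what prevents a misbehaving model from looping on tool calls forever.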

## Secrets

Store your Google API key as a Flyte secret and expose it as an environment variable:

```python
secrets=flyte.Secret("google_api_key", as_env_var="GOOGLE_API_KEY")
```

## API reference

See the [Gemini API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/gemini/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/openai ===

# OpenAI

The OpenAI plugin provides a drop-in replacement for the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/) `function_tool` decorator. It lets you use Flyte tasks as tools in agentic workflows so that tool calls run as tracked, reproducible Flyte task executions.

## When to use this plugin

- Building agentic workflows with the OpenAI Agents SDK on Flyte
- You want tool calls to run as Flyte tasks with full observability, retries, and caching
- You want to combine LLM agents with existing Flyte pipelines

## Installation

```bash
pip install flyteplugins-openai
```

Requires `openai-agents >= 0.2.4`.

## Usage

The plugin provides a single decorator, `function_tool`, that wraps Flyte tasks as OpenAI agent tools.

### `function_tool`

When applied to a Flyte task (a function decorated with `@env.task`), `function_tool` makes that task available as an OpenAI `FunctionTool`. The agent can call it like any other tool, and the call executes as a Flyte task.

When applied to a regular function or a `@flyte.trace`-decorated function, it delegates directly to the OpenAI Agents SDK's built-in `function_tool`.

### Basic pattern

1. Define a `TaskEnvironment` with your image and secrets
2. Decorate your task functions with `@function_tool` and `@env.task`
3. Pass the tools to an `Agent`
4. Run the agent from another Flyte task

```python
from agents import Agent, Runner
from flyteplugins.openai.agents import function_tool

env = flyte.TaskEnvironment(
    name="openai_agents",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
    secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)

@function_tool
@env.task
async def get_weather(city: str) -> Weather:
    """Get the weather for a given city."""
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny")

agent = Agent(
    name="Weather Agent",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)

@env.task
async def main() -> str:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    return result.final_output
```

> [!NOTE]
> The docstring on each `@function_tool` task is sent to the LLM as the tool description. Write clear, concise docstrings that describe what the tool does and what its parameters mean.

### Secrets

Store your OpenAI API key as a Flyte secret and expose it as an environment variable:

```python
secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY")
```

## Example

````python
"""OpenAI Agents with Flyte, basic tool example.

Usage:

Create secret:

```
flyte create secret openai_api_key
uv run agents_tools.py
```
"""
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-openai>=2.0.0b7",
#    "openai-agents>=0.2.4",
#    "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///

from agents import Agent, Runner
from pydantic import BaseModel

import flyte
from flyteplugins.openai.agents import function_tool

env = flyte.TaskEnvironment(
    name="openai_agents_tools",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
    secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)

class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str

@function_tool
@env.task
async def get_weather(city: str) -> Weather:
    """Get the weather for a given city."""
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")

agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)

@env.task
async def main() -> str:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
    return result.final_output

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
````

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

## API reference

See the [OpenAI API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/openai/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/openai/agent_tools ===

# Agent tools

In this example, we will use the `openai-agents` library to create a simple agent that can use tools to perform tasks.
This example is based on the [basic tools example](https://github.com/openai/openai-agents-python/blob/main/examples/basic/tools.py) from the `openai-agents-python` repo.

First, create an OpenAI API key, which you can get from the [OpenAI website](https://platform.openai.com/account/api-keys).
Then, create a secret on your Flyte cluster with:

```bash
flyte create secret openai_api_key --value <your-api-key>
```

Then, we'll use an inline `uv` script header to specify our dependencies:

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-openai>=2.0.0b7",
#    "openai-agents>=0.2.4",
#    "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

Next, we'll import the libraries and create a `TaskEnvironment`, which we need to run the example:

```python
from agents import Agent, Runner
from pydantic import BaseModel

import flyte
from flyteplugins.openai.agents import function_tool

env = flyte.TaskEnvironment(
    name="openai_agents_tools",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
    secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

## Define the tools

We'll define a tool that can get weather information for a
given city. In this case, we'll use a toy function that returns a hard-coded `Weather` object.

```python
class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str

@function_tool
@env.task
async def get_weather(city: str) -> Weather:
    """Get the weather for a given city."""
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

In this code snippet, the `@function_tool` decorator is imported from `flyteplugins.openai.agents`; it is a drop-in replacement for the `@function_tool` decorator from the `openai-agents` library.

## Define the agent

Then, we'll define the agent, which calls the tool:

```python
agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)

@env.task
async def main() -> str:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
    return result.final_output
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

## Run the agent

Finally, we'll run the agent. Create a `config.yaml` file, which `flyte.init_from_config()` will use to connect to the Flyte cluster:

```bash
flyte create config \
--output ~/.flyte/config.yaml \
--endpoint demo.hosted.unionai.cloud/ \
--project flytesnacks \
--domain development \
--builder remote
```

```python
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

## Conclusion

In this example, we've seen how to use the `openai-agents` library to create a simple agent that can use tools to perform tasks.

The full code is available [here](https://github.com/unionai/unionai-examples/tree/main/v2/integrations/flyte-plugins/openai/openai).

=== PAGE: https://www.union.ai/docs/v2/union/integrations/pytorch ===

# PyTorch

The PyTorch plugin lets you run distributed [PyTorch](https://pytorch.org/) training jobs natively on Kubernetes. It uses the [Kubeflow Training Operator](https://github.com/kubeflow/training-operator) to manage multi-node training with PyTorch's elastic launch (`torchrun`).

## When to use this plugin

- Single-node or multi-node distributed training with `DistributedDataParallel` (DDP)
- Elastic training that can scale up and down during execution
- Any workload that uses `torch.distributed` for data-parallel or model-parallel training

## Installation

```bash
pip install flyteplugins-pytorch
```

## Configuration

Create an `Elastic` configuration and pass it as `plugin_config` to a `TaskEnvironment`:

```python
from flyteplugins.pytorch import Elastic

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nnodes=2,
        nproc_per_node=1,
    ),
    image=image,
)
```

### `Elastic` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `nnodes` | `int` or `str` | **Required.** Number of nodes. Use an int for a fixed count or a range string (e.g., `"2:4"`) for elastic training |
| `nproc_per_node` | `int` | **Required.** Number of processes (workers) per node |
| `rdzv_backend` | `str` | Rendezvous backend: `"c10d"` (default), `"etcd"`, or `"etcd-v2"` |
| `max_restarts` | `int` | Maximum worker group restarts (default: `3`) |
| `monitor_interval` | `int` | Agent health check interval in seconds (default: `3`) |
| `run_policy` | `RunPolicy` | Job run policy (cleanup, TTL, deadlines, retries) |
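
For example, elastic training across two to four nodes can be requested by passing a range string for `nnodes`. A minimal sketch using the parameters documented above:

```python
from flyteplugins.pytorch import Elastic

# Scale the worker group between 2 and 4 nodes; the rendezvous backend
# (c10d by default) lets workers join or leave within this range, and
# the worker group is restarted up to max_restarts times on failure.
elastic_config = Elastic(
    nnodes="2:4",
    nproc_per_node=1,
    max_restarts=3,
)
```

Pass `elastic_config` as `plugin_config` to a `TaskEnvironment` exactly as in the fixed-count example above.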

### `RunPolicy` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `clean_pod_policy` | `str` | Pod cleanup policy: `"None"`, `"all"`, or `"Running"` |
| `ttl_seconds_after_finished` | `int` | Seconds to keep pods after job completion |
| `active_deadline_seconds` | `int` | Maximum time the job can run (seconds) |
| `backoff_limit` | `int` | Number of retries before marking the job as failed |

### NCCL tuning parameters

The plugin includes built-in NCCL timeout tuning to reduce failure-detection latency (PyTorch defaults to 1800 seconds):

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nccl_heartbeat_timeout_sec` | `int` | `300` | NCCL heartbeat timeout (seconds) |
| `nccl_async_error_handling` | `bool` | `False` | Enable async NCCL error handling |
| `nccl_collective_timeout_sec` | `int` | `None` | Timeout for NCCL collective operations |
| `nccl_enable_monitoring` | `bool` | `True` | Enable NCCL monitoring |

### Writing a distributed training task

Tasks using this plugin do not need to be `async`. Initialize the process group and use `DistributedDataParallel` as you normally would with `torchrun`:

```python
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP

@torch_env.task
def train(epochs: int) -> float:
    torch.distributed.init_process_group("gloo")
    model = DDP(MyModel())
    # ... training loop ...
    return final_loss
```

> [!NOTE]
> When `nnodes=1`, the task runs as a regular Python task (no Kubernetes training job is created). Set `nnodes >= 2` for multi-node distributed training.

## Example

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-pytorch",
#    "torch"
# ]
# main = "torch_distributed_train"
# params = "3"
# ///

import typing

import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from flyteplugins.pytorch.task import Elastic
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

import flyte

image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch", pre=True)

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nproc_per_node=1,
        # if you want to do local testing set nnodes=1
        nnodes=2,
    ),
    image=image,
)

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataLoader:
    """
    Prepare a DataLoader with a DistributedSampler so each rank
    gets a shard of the dataset.
    """
    # Dummy dataset
    x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])
    dataset = TensorDataset(x_train, y_train)

    # Distributed-aware sampler
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

def train_loop(epochs: int = 3) -> float:
    """
    A simple training loop for linear regression.
    """
    torch.distributed.init_process_group("gloo")
    model = DDP(LinearRegressionModel())

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    dataloader = prepare_dataloader(
        rank=rank,
        world_size=world_size,
        batch_size=64,
    )

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    final_loss = 0.0

    for _ in range(epochs):
        for x, y in dataloader:
            outputs = model(x)
            loss = criterion(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            final_loss = loss.item()
        if torch.distributed.get_rank() == 0:
            print(f"Loss: {final_loss}")

    return final_loss

@torch_env.task
def torch_distributed_train(epochs: int) -> typing.Optional[float]:
    """
    A nested task that sets up a simple distributed training job using PyTorch's
    """
    print("starting launcher")
    loss = train_loop(epochs=epochs)
    print("Training complete")
    return loss

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(torch_distributed_train, epochs=3)
    print(r.name)
    print(r.url)
    r.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/pytorch/pytorch_example.py*

## API reference

See the [PyTorch API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/pytorch/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/ray ===

# Ray

The Ray plugin lets you run [Ray](https://www.ray.io/) jobs natively on Kubernetes. Flyte provisions a transient Ray cluster for each task execution using [KubeRay](https://github.com/ray-project/kuberay) and tears it down on completion.

## When to use this plugin

- Distributed Python workloads (parallel computation, data processing)
- ML training with Ray Train or hyperparameter tuning with Ray Tune
- Ray Serve inference workloads
- Any workload that benefits from Ray's actor model or task parallelism

## Installation

```bash
pip install flyteplugins-ray
```

Your task image must also include a compatible version of Ray:

```python
image = (
    flyte.Image.from_debian_base(name="ray")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)
```

## Configuration

Create a `RayJobConfig` and pass it as `plugin_config` to a `TaskEnvironment`:

```python
from flyteplugins.ray import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)

ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image=image,
)
```

### `RayJobConfig` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `worker_node_config` | `List[WorkerNodeConfig]` | **Required.** List of worker group configurations |
| `head_node_config` | `HeadNodeConfig` | Head node configuration (optional) |
| `enable_autoscaling` | `bool` | Enable Ray autoscaler (default: `False`) |
| `runtime_env` | `dict` | Ray runtime environment (pip packages, env vars, etc.) |
| `address` | `str` | Connect to an existing Ray cluster instead of provisioning one |
| `shutdown_after_job_finishes` | `bool` | Shut down the cluster after the job completes (default: `False`) |
| `ttl_seconds_after_finished` | `int` | Seconds to keep the cluster after completion before cleanup |

### `WorkerNodeConfig` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `group_name` | `str` | **Required.** Name of this worker group |
| `replicas` | `int` | **Required.** Number of worker replicas |
| `min_replicas` | `int` | Minimum replicas (for autoscaling) |
| `max_replicas` | `int` | Maximum replicas (for autoscaling) |
| `ray_start_params` | `Dict[str, str]` | Ray start parameters for workers |
| `requests` | `Resources` | Resource requests per worker |
| `limits` | `Resources` | Resource limits per worker |
| `pod_template` | `PodTemplate` | Full pod template (mutually exclusive with `requests`/`limits`) |

### `HeadNodeConfig` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `ray_start_params` | `Dict[str, str]` | Ray start parameters for the head node |
| `requests` | `Resources` | Resource requests for the head node |
| `limits` | `Resources` | Resource limits for the head node |
| `pod_template` | `PodTemplate` | Full pod template (mutually exclusive with `requests`/`limits`) |
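The parameters above can be combined into an autoscaling cluster configuration. The following is a hedged sketch (resource values and group names are illustrative, not recommendations):

```python
from flyteplugins.ray import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

import flyte

# Illustrative autoscaling setup: with enable_autoscaling=True, the Ray
# autoscaler may scale the worker group between min_replicas and max_replicas.
autoscaling_config = RayJobConfig(
    head_node_config=HeadNodeConfig(
        requests=flyte.Resources(cpu=1, memory="2Gi"),
    ),
    worker_node_config=[
        WorkerNodeConfig(
            group_name="autoscale-group",
            replicas=2,  # initial replica count
            min_replicas=1,
            max_replicas=8,
            requests=flyte.Resources(cpu=2, memory="4Gi"),
        )
    ],
    enable_autoscaling=True,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)
```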

### Connecting to an existing cluster

To connect to an existing Ray cluster instead of provisioning a new one, set the `address` parameter:

```python
ray_config = RayJobConfig(
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    address="ray://existing-cluster:10001",
)
```

## Examples

The following example shows how to configure Ray in a `TaskEnvironment`. Flyte automatically provisions a Ray cluster for each task using this configuration:

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-ray",
#    "ray[default]==2.46.0"
# ]
# main = "hello_ray_nested"
# params = "3"
# ///

import asyncio
import typing

import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

import flyte.remote
import flyte.storage

@ray.remote
def f(x):
    return x * x

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)

image = (
    flyte.Image.from_debian_base(name="ray")
    .with_apt_packages("wget")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray", "pip", "mypy")
)

task_env = flyte.TaskEnvironment(
    name="hello_ray", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(3, 4), memory=("3000Mi", "5000Mi")),
    depends_on=[task_env],
)

@task_env.task()
async def hello_ray():
    await asyncio.sleep(20)
    print("Hello from the Ray task!")

@ray_env.task
async def hello_ray_nested(n: int = 3) -> typing.List[int]:
    print("running ray task")
    t = asyncio.create_task(hello_ray())
    futures = [f.remote(i) for i in range(n)]
    res = ray.get(futures)
    await t
    return res

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_ray_nested)
    print(r.name)
    print(r.url)
    r.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/ray/ray_example.py*

The next example demonstrates how Flyte can create ephemeral Ray clusters and run a subtask that connects to an existing Ray cluster:

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-ray",
#    "ray[default]==2.46.0"
# ]
# main = "create_ray_cluster"
# params = ""
# ///

import os
import typing

import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

import flyte.storage

@ray.remote
def f(x):
    return x * x

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=3600,
)

image = (
    flyte.Image.from_debian_base(name="ray")
    .with_apt_packages("wget")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)

task_env = flyte.TaskEnvironment(
    name="ray_client", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
ray_env = flyte.TaskEnvironment(
    name="ray_cluster",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(2, 4), memory=("2000Mi", "4000Mi")),
    depends_on=[task_env],
)

@task_env.task()
async def hello_ray(cluster_ip: str) -> typing.List[int]:
    """
    Run a simple Ray task that connects to an existing Ray cluster.
    """
    ray.init(address=f"ray://{cluster_ip}:10001")
    futures = [f.remote(i) for i in range(5)]
    res = ray.get(futures)
    return res

@ray_env.task
async def create_ray_cluster() -> str:
    """
    Create a Ray cluster and return the head node IP address.
    """
    print("creating ray cluster")
    cluster_ip = os.getenv("MY_POD_IP")
    if cluster_ip is None:
        raise ValueError("MY_POD_IP environment variable is not set")
    return f"{cluster_ip}"

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(create_ray_cluster)
    run.wait()
    print("run url:", run.url)
    print("cluster created, running ray task")
    print("ray address:", run.outputs()[0])
    run = flyte.run(hello_ray, cluster_ip=run.outputs()[0])
    print("run url:", run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/ray/ray_existing_example.py*

## API reference

See the [Ray API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/ray/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/snowflake ===

# Snowflake

The Snowflake connector lets you run SQL queries against [Snowflake](https://www.snowflake.com/) directly from Flyte tasks. Queries are submitted asynchronously and polled for completion, so they don't block a worker while waiting for results.

The connector supports:

- Parameterized SQL queries with typed inputs
- Key-pair and password-based authentication
- Query results returned as DataFrames
- Automatic links to the Snowflake query dashboard in the Flyte UI
- Query cancellation on task abort

## Installation

```bash
pip install flyteplugins-snowflake
```

This installs the Snowflake Python connector and the `cryptography` library for key-pair authentication.

## Quick start

Here's a minimal example that runs a SQL query on Snowflake:

```python {hl_lines=[2, 4, 12]}
from flyte.io import DataFrame
from flyteplugins.connectors.snowflake import Snowflake, SnowflakeConfig

config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
)

count_users = Snowflake(
    name="count_users",
    query_template="SELECT COUNT(*) FROM users",
    plugin_config=config,
    output_dataframe_type=DataFrame,
)
```

This defines a task called `count_users` that runs `SELECT COUNT(*) FROM users` on the configured Snowflake instance. When executed, the connector:

1. Connects to Snowflake using the provided configuration
2. Submits the query asynchronously
3. Polls until the query completes or fails
4. Provides a link to the query in the Snowflake dashboard

![Snowflake Link](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/snowflake/ui.png)

To run the task, create a `TaskEnvironment` from it and execute it locally or remotely:

```python {hl_lines=3}
import flyte

snowflake_env = flyte.TaskEnvironment.from_task("snowflake_env", count_users)

if __name__ == "__main__":
    flyte.init_from_config()

    # Run locally (connector runs in-process, requires credentials and packages locally)
    run = flyte.with_runcontext(mode="local").run(count_users)

    # Run remotely (connector runs on the control plane)
    run = flyte.with_runcontext(mode="remote").run(count_users)

    print(run.url)
```

> [!NOTE]
> The `TaskEnvironment` created by `from_task` does not need an image or pip packages. Snowflake tasks are connector tasks, which means the query executes on the connector service, not in your task container. In `local` mode, the connector runs in-process and requires `flyteplugins-snowflake` and credentials to be available on your machine. In `remote` mode, the connector runs on the control plane.

## Configuration

The `SnowflakeConfig` dataclass defines the connection settings for your Snowflake instance.

### Required fields

| Field       | Type  | Description                                             |
| ----------- | ----- | ------------------------------------------------------- |
| `account`   | `str` | Snowflake account identifier (e.g. `"myorg-myaccount"`) |
| `database`  | `str` | Target database name                                    |
| `schema`    | `str` | Target schema name (e.g. `"PUBLIC"`)                    |
| `warehouse` | `str` | Compute warehouse to use for query execution            |
| `user`      | `str` | Snowflake username                                      |

### Additional connection parameters

Use `connection_kwargs` to pass any additional parameters supported by the [Snowflake Python connector](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api). This is a dictionary that gets forwarded directly to `snowflake.connector.connect()`.

Common options include:

| Parameter       | Type  | Description                                                                |
| --------------- | ----- | -------------------------------------------------------------------------- |
| `role`          | `str` | Snowflake role to use for the session                                      |
| `authenticator` | `str` | Authentication method (e.g. `"snowflake"`, `"externalbrowser"`, `"oauth"`) |
| `token`         | `str` | OAuth token when using `authenticator="oauth"`                             |
| `login_timeout` | `int` | Timeout in seconds for the login request                                   |

Example with a role:

```python {hl_lines=8}
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "role": "DATA_ANALYST",
    },
)
```

## Authentication

The connector supports two authentication approaches: key-pair authentication, and password-based or other authentication methods provided through `connection_kwargs`.

### Key-pair authentication

Key-pair authentication is the recommended approach for automated workloads. Pass the names of the Flyte secrets containing the private key and optional passphrase:

```python {hl_lines=[5, 6]}
query = Snowflake(
    name="secure_query",
    query_template="SELECT * FROM sensitive_data",
    plugin_config=config,
    snowflake_private_key="my-snowflake-private-key",
    snowflake_private_key_passphrase="my-snowflake-pk-passphrase",
)
```

The `snowflake_private_key` parameter is the name of the secret (or secret key) that contains your PEM-encoded private key. The `snowflake_private_key_passphrase` parameter is the name of the secret (or secret key) that contains the passphrase, if your key is encrypted. If your key is not encrypted, omit the passphrase parameter.

The connector decodes the PEM key and converts it to DER format for Snowflake authentication.

> [!NOTE]
> If your credentials are stored in a secret group, you can pass `secret_group` to the `Snowflake` task. The plugin expects `snowflake_private_key` and
> `snowflake_private_key_passphrase` to be keys within the same secret group.

### Password authentication

Send the password via `connection_kwargs`:

```python {hl_lines=8}
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "password": "my-password",
    },
)
```

### OAuth authentication

For OAuth-based authentication, specify the authenticator and token:

```python {hl_lines=["8-9"]}
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "authenticator": "oauth",
        "token": "<oauth-token>",
    },
)
```

## Query templating

Use the `inputs` parameter to define typed inputs for your query. Input values are bound using the `%(param)s` syntax supported by the [Snowflake Python connector](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api), which handles type conversion and escaping automatically.

### Supported input types

The `inputs` dictionary maps parameter names to Python values. Supported scalar types include `str`, `int`, `float`, and `bool`.

To insert multiple rows in a single query, you can also provide lists as input values. When using list inputs, be sure to set `batch=True` on the `Snowflake` task. This enables automatic batching: the inputs are expanded and sent as a single multi-row query, so you don't have to write multiple individual `INSERT` statements.

### Batched `INSERT` with list inputs

When `batch=True` is enabled, a parameterized `INSERT` query with list inputs is automatically expanded into a multi-row `VALUES` statement.

Example:

```python
query = "INSERT INTO t (a, b) VALUES (%(a)s, %(b)s)"
inputs = {"a": [1, 2], "b": ["x", "y"]}
```

This is expanded into:

```sql
INSERT INTO t (a, b)
VALUES (%(a_0)s, %(b_0)s), (%(a_1)s, %(b_1)s)
```

with the following flattened parameters:

```python
flat_params = {
    "a_0": 1,
    "b_0": "x",
    "a_1": 2,
    "b_1": "y",
}
```

#### Constraints

- The query must contain exactly one `VALUES (...)` clause.
- All list inputs must have the same non-zero length.
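To make the expansion concrete, here is a simplified sketch of the flattening logic. It is illustrative only, not the plugin's implementation, and assumes the `VALUES (...)` clause ends the statement:

```python
import re

def expand_batch_insert(query: str, inputs: dict) -> tuple[str, dict]:
    # Locate the single VALUES (...) clause; the greedy match to the last
    # closing paren assumes the clause ends the statement.
    match = re.search(r"VALUES\s*\((.*)\)", query, re.IGNORECASE | re.DOTALL)
    if match is None:
        raise ValueError("query must contain a VALUES (...) clause")
    row_template = match.group(1)  # e.g. "%(a)s, %(b)s"
    names = re.findall(r"%\((\w+)\)s", row_template)
    lengths = {len(inputs[name]) for name in names}
    if len(lengths) != 1 or 0 in lengths:
        raise ValueError("all list inputs must have the same non-zero length")
    rows, flat_params = [], {}
    for i in range(lengths.pop()):
        row = row_template
        for name in names:
            # Suffix each placeholder with the row index and flatten the value.
            row = row.replace(f"%({name})s", f"%({name}_{i})s")
            flat_params[f"{name}_{i}"] = inputs[name][i]
        rows.append(f"({row})")
    expanded = query[: match.start()] + "VALUES " + ", ".join(rows) + query[match.end():]
    return expanded, flat_params
```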

### Parameterized `SELECT`

```python {hl_lines=[5, 7]}
from flyte.io import DataFrame

events_by_date = Snowflake(
    name="events_by_date",
    query_template="SELECT * FROM events WHERE event_date = %(event_date)s",
    plugin_config=config,
    inputs={"event_date": str},
    output_dataframe_type=DataFrame,
)
```

You can call the task with the required inputs:

```python {hl_lines=3}
@env.task
async def fetch_events() -> DataFrame:
    return await events_by_date(event_date="2024-01-15")
```

### Multiple inputs

You can define multiple input parameters of different types:

```python {hl_lines=["4-8", "12-15"]}
filtered_events = Snowflake(
    name="filtered_events",
    query_template="""
        SELECT * FROM events
        WHERE event_date >= %(start_date)s
          AND event_date <= %(end_date)s
          AND region = %(region)s
          AND score > %(min_score)s
    """,
    plugin_config=config,
    inputs={
        "start_date": str,
        "end_date": str,
        "region": str,
        "min_score": float,
    },
    output_dataframe_type=DataFrame,
)
```

> [!NOTE]
> The query template is normalized before execution: newlines and tabs are replaced with spaces, and consecutive whitespace is collapsed. You can format your queries across multiple lines for readability without affecting execution.
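The described normalization behaves roughly like a whitespace collapse. A minimal sketch, not the plugin's exact code:

```python
def normalize_query(query: str) -> str:
    # Replace newlines/tabs with spaces and collapse consecutive whitespace.
    return " ".join(query.split())
```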

## Retrieving query results

If your query produces output, set `output_dataframe_type` to capture the results. `output_dataframe_type` accepts `DataFrame` from `flyte.io`, a wrapper type that represents tabular results. You can materialize it into a concrete DataFrame implementation by calling `open()` with the target type, followed by `all()`.

```python {hl_lines=13}
from flyte.io import DataFrame

top_customers = Snowflake(
    name="top_customers",
    query_template="""
        SELECT customer_id, SUM(amount) AS total_spend
        FROM orders
        GROUP BY customer_id
        ORDER BY total_spend DESC
        LIMIT 100
    """,
    plugin_config=config,
    output_dataframe_type=DataFrame,
)
```

At present, only `pandas.DataFrame` is supported. The results are returned directly when you call the task:

```python {hl_lines=6}
import pandas as pd

@env.task
async def analyze_top_customers() -> dict:
    df = await top_customers()
    pandas_df = await df.open(pd.DataFrame).all()
    total_spend = pandas_df["total_spend"].sum()
    return {"total_spend": float(total_spend)}
```

If you specify `pandas.DataFrame` as the `output_dataframe_type`, you do not need to call `open()` and `all()` to materialize the results.

```python {hl_lines=[1, 13, "18-19"]}
import pandas as pd

top_customers = Snowflake(
    name="top_customers",
    query_template="""
        SELECT customer_id, SUM(amount) AS total_spend
        FROM orders
        GROUP BY customer_id
        ORDER BY total_spend DESC
        LIMIT 100
    """,
    plugin_config=config,
    output_dataframe_type=pd.DataFrame,
)

@env.task
async def analyze_top_customers() -> dict:
    df = await top_customers()
    total_spend = df["total_spend"].sum()
    return {"total_spend": float(total_spend)}
```

> [!NOTE]
> Be sure to inject the `SNOWFLAKE_PRIVATE_KEY` and `SNOWFLAKE_PRIVATE_KEY_PASSPHRASE` environment variables as secrets into your downstream tasks, as they must have access to Snowflake credentials in order to retrieve the DataFrame results. More on how you can create secrets [here](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets/page.md).

If you don't need query results (for example, `DDL` statements or `INSERT` queries), omit `output_dataframe_type`.

## End-to-end example

Here's a complete workflow that uses the Snowflake connector as part of a data pipeline. The workflow creates a staging table, inserts records, queries aggregated results and processes them in a downstream task.

```python
import flyte
from flyte.io import DataFrame
from flyteplugins.connectors.snowflake import Snowflake, SnowflakeConfig

config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "role": "ETL_ROLE",
    },
)

# Step 1: Create the staging table if it doesn't exist
create_staging = Snowflake(
    name="create_staging",
    query_template="""
        CREATE TABLE IF NOT EXISTS staging.daily_events (
            event_id STRING,
            event_date DATE,
            user_id STRING,
            event_type STRING,
            payload VARIANT
        )
    """,
    plugin_config=config,
    snowflake_private_key="snowflake",
    snowflake_private_key_passphrase="snowflake_passphrase",
)

# Step 2: Insert a record into the staging table
insert_events = Snowflake(
    name="insert_event",
    query_template="""
        INSERT INTO staging.daily_events (event_id, event_date, user_id, event_type)
        VALUES (%(event_id)s, %(event_date)s, %(user_id)s, %(event_type)s)
    """,
    plugin_config=config,
    inputs={
        "event_id": list[str],
        "event_date": list[str],
        "user_id": list[str],
        "event_type": list[str],
    },
    snowflake_private_key="snowflake",
    snowflake_private_key_passphrase="snowflake_passphrase",
    batch=True,
)

# Step 3: Query aggregated results for a given date
daily_summary = Snowflake(
    name="daily_summary",
    query_template="""
        SELECT
            event_type,
            COUNT(*) AS event_count,
            COUNT(DISTINCT user_id) AS unique_users
        FROM staging.daily_events
        WHERE event_date = %(report_date)s
        GROUP BY event_type
        ORDER BY event_count DESC
    """,
    plugin_config=config,
    inputs={"report_date": str},
    output_dataframe_type=DataFrame,
    snowflake_private_key="snowflake",
    snowflake_private_key_passphrase="snowflake_passphrase",
)

# Create environments for all Snowflake tasks
snowflake_env = flyte.TaskEnvironment.from_task(
    "snowflake_env", create_staging, insert_events, daily_summary
)

# Main pipeline environment that depends on the Snowflake task environments
env = flyte.TaskEnvironment(
    name="analytics_env",
    resources=flyte.Resources(memory="512Mi"),
    image=flyte.Image.from_debian_base(name="analytics").with_pip_packages(
        "flyteplugins-snowflake", pre=True
    ),
    secrets=[
        flyte.Secret(key="snowflake", as_env_var="SNOWFLAKE_PRIVATE_KEY"),
        flyte.Secret(
            key="snowflake_passphrase", as_env_var="SNOWFLAKE_PRIVATE_KEY_PASSPHRASE"
        ),
    ],
    depends_on=[snowflake_env],
)

# Step 4: Process the results in Python
@env.task
async def generate_report(summary: DataFrame) -> dict:
    import pandas as pd

    df = await summary.open(pd.DataFrame).all()
    total_events = df["event_count"].sum()
    top_event = df.iloc[0]["event_type"]
    return {
        "total_events": int(total_events),
        "top_event_type": top_event,
        "event_types_count": len(df),
    }

# Compose the pipeline
@env.task
async def run_daily_pipeline(
    event_ids: list[str],
    event_dates: list[str],
    user_ids: list[str],
    event_types: list[str],
) -> dict:
    await create_staging()
    await insert_events(
        event_id=event_ids,
        event_date=event_dates,
        user_id=user_ids,
        event_type=event_types,
    )
    summary = await daily_summary(report_date=event_dates[0])
    return await generate_report(summary=summary)

if __name__ == "__main__":
    flyte.init_from_config()

    # Run locally
    run = flyte.with_runcontext(mode="local").run(
        run_daily_pipeline,
        event_ids=["event-1", "event-2"],
        event_dates=["2023-01-01", "2023-01-02"],
        user_ids=["user-1", "user-2"],
        event_types=["click", "view"],
    )

    # Or run remotely
    run = flyte.with_runcontext(mode="remote").run(
        run_daily_pipeline,
        event_ids=["event-1", "event-2"],
        event_dates=["2023-01-01", "2023-01-02"],
        user_ids=["user-1", "user-2"],
        event_types=["click", "view"],
    )

    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/connectors/snowflake/example.py*

=== PAGE: https://www.union.ai/docs/v2/union/integrations/spark ===

# Spark

The Spark plugin lets you run [Apache Spark](https://spark.apache.org/) jobs natively on Kubernetes. Flyte manages the full cluster lifecycle: provisioning a transient Spark cluster for each task execution, running the job, and tearing the cluster down on completion.

Under the hood, the plugin uses the [Spark on Kubernetes Operator](https://github.com/GoogleCloudPlatform/spark-on-k8s-operator) to create and manage Spark applications. No external Spark service or long-running cluster is required.

## When to use this plugin

- Large-scale data processing and ETL pipelines
- Jobs that benefit from Spark's distributed execution engine (Spark SQL, PySpark, Spark MLlib)
- Workloads that need Hadoop-compatible storage access (S3, GCS, HDFS)

## Installation

```bash
pip install flyteplugins-spark
```

## Configuration

Create a `Spark` configuration and pass it as `plugin_config` to a `TaskEnvironment`:

```python
from flyteplugins.spark import Spark

spark_config = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    plugin_config=spark_config,
    image=image,
)
```

### `Spark` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `spark_conf` | `Dict[str, str]` | Spark configuration key-value pairs (e.g., executor memory, cores, instances) |
| `hadoop_conf` | `Dict[str, str]` | Hadoop configuration key-value pairs (e.g., S3/GCS access settings) |
| `executor_path` | `str` | Path to the Python binary for PySpark executors |
| `applications_path` | `str` | Path to the main Spark application file |
| `driver_pod` | `PodTemplate` | Pod template for the Spark driver pod |
| `executor_pod` | `PodTemplate` | Pod template for the Spark executor pods |

### Accessing the Spark session

Inside a Spark task, the `SparkSession` is available through the task context:

```python
from flyte._context import internal_ctx

@spark_env.task
async def my_spark_task() -> float:
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    # Use spark as a normal SparkSession
    df = spark.read.parquet("s3://my-bucket/data.parquet")
    return df.count()
```

### Overriding configuration at runtime

You can override Spark configuration for individual task calls using `.override()`:

```python
from copy import deepcopy

updated_config = deepcopy(spark_config)
updated_config.spark_conf["spark.executor.instances"] = "4"

result = await my_spark_task.override(plugin_config=updated_config)()
```

## Example

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-spark"
# ]
# main = "hello_spark_nested"
# params = "3"
# ///

import random
from copy import deepcopy
from operator import add

from flyteplugins.spark.task import Spark

import flyte.remote
from flyte._context import internal_ctx

image = (
    flyte.Image.from_base("apache/spark-py:v3.4.0")
    .clone(name="spark", python_version=(3, 10), registry="ghcr.io/flyteorg")
    .with_pip_packages("flyteplugins-spark", pre=True)
)

task_env = flyte.TaskEnvironment(
    name="get_pi", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
        "spark.kubernetes.file.upload.path": "/opt/spark/work-dir",
        "spark.jars": "https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar,https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.2.2/hadoop-aws-3.2.2.jar,https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("3000Mi", "5000Mi")),
    plugin_config=spark_conf,
    image=image,
    depends_on=[task_env],
)

def f(_):
    x = random.random() * 2 - 1
    y = random.random() * 2 - 1
    return 1 if x**2 + y**2 <= 1 else 0

@task_env.task
async def get_pi(count: int, partitions: int) -> float:
    return 4.0 * count / partitions

@spark_env.task
async def hello_spark_nested(partitions: int = 3) -> float:
    n = 1 * partitions
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)

    return await get_pi(count, partitions)

@task_env.task
async def spark_overrider(executor_instances: int = 3, partitions: int = 4) -> float:
    updated_spark_conf = deepcopy(spark_conf)
    updated_spark_conf.spark_conf["spark.executor.instances"] = str(executor_instances)
    return await hello_spark_nested.override(plugin_config=updated_spark_conf)(partitions=partitions)

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_spark_nested)
    print(r.name)
    print(r.url)
    r.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/spark/spark_example.py*

## API reference

See the [Spark API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/spark/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb ===

# Weights & Biases

[Weights & Biases](https://wandb.ai) (W&B) is a platform for tracking machine learning experiments, visualizing metrics and optimizing hyperparameters. This plugin integrates W&B with Flyte, enabling you to:

- Automatically initialize W&B runs in your tasks without boilerplate
- Link directly from the Flyte UI to your W&B runs and sweeps
- Share W&B runs across parent and child tasks
- Track distributed training jobs across multiple GPUs and nodes
- Run hyperparameter sweeps with parallel agents

## Installation

```bash
pip install flyteplugins-wandb
```

You also need a W&B API key. Store it as a Flyte secret so your tasks can authenticate with W&B.

## Quick start

Here's a minimal example that logs metrics to W&B from a Flyte task:

```python
import flyte

from flyteplugins.wandb import get_wandb_run, wandb_config, wandb_init

env = flyte.TaskEnvironment(
    name="wandb-example",
    image=flyte.Image.from_debian_base(name="wandb-example").with_pip_packages(
        "flyteplugins-wandb"
    ),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@wandb_init
@env.task
async def train_model() -> str:
    wandb_run = get_wandb_run()

    # Your training code here
    for epoch in range(10):
        loss = 1.0 / (epoch + 1)
        wandb_run.log({"epoch": epoch, "loss": loss})

    return "Training complete"

if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext(
        custom_context=wandb_config(
            project="my-project",
            entity="my-team",
        ),
    ).run(train_model)

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/quick_start.py*

This example demonstrates the core pattern:

1. **Define a task environment** with the plugin installed and your W&B API key as a secret
2. **Decorate your task** with `@wandb_init` (must be the outermost decorator, above `@env.task`)
3. **Access the run** with `get_wandb_run()` to log metrics
4. **Provide configuration** via `wandb_config()` when running the task

The plugin handles calling `wandb.init()` and `wandb.finish()` for you, and automatically adds a link to the W&B run in the Flyte UI.

![UI](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/wandb/ui.png)

## What's next

This integration guide is split into focused sections, depending on how you want to use Weights & Biases with Flyte:

- **Weights & Biases > Experiments**: Create and manage W&B runs from Flyte tasks.
- **Weights & Biases > Distributed training**: Track experiments across multi-GPU and multi-node training jobs.
- **Weights & Biases > Sweeps**: Run hyperparameter searches and manage sweep execution from Flyte tasks.
- **Weights & Biases > Downloading logs**: Download logs and execution metadata from Weights & Biases.
- **Weights & Biases > Constraints and best practices**: Learn about limitations, edge cases and recommended patterns.
- **Weights & Biases > Manual integration**: Use Weights & Biases directly in Flyte tasks without decorators or helpers.

> **📝 Note**
>
> We've included additional examples developed while testing edge cases of the plugin [here](https://github.com/flyteorg/flyte-sdk/tree/main/plugins/wandb/examples).

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/experiments ===

# Experiments

The `@wandb_init` decorator automatically initializes a W&B run when your task executes and finishes it when the task completes. This section covers the different ways to use it.

## Basic usage

Apply `@wandb_init` as the outermost decorator on your task:

```python {hl_lines=1}
@wandb_init
@env.task
async def my_task() -> str:
    run = get_wandb_run()
    run.log({"metric": 42})
    return "done"
```

The decorator:

- Calls `wandb.init()` before your task code runs
- Calls `wandb.finish()` after your task completes (or fails)
- Adds a link to the W&B run in the Flyte UI

You can also use it on synchronous tasks:

```python {hl_lines=[1, 3]}
@wandb_init
@env.task
def my_sync_task() -> str:
    run = get_wandb_run()
    run.log({"metric": 42})
    return "done"
```

## Accessing the run object

Use `get_wandb_run()` to access the current W&B run object:

```python {hl_lines=6}
from flyteplugins.wandb import get_wandb_run

@wandb_init
@env.task
async def train() -> str:
    run = get_wandb_run()

    # Log metrics
    run.log({"loss": 0.5, "accuracy": 0.9})

    # Access run properties
    print(f"Run ID: {run.id}")
    print(f"Run URL: {run.url}")
    print(f"Project: {run.project}")

    # Log configuration
    run.config.update({"learning_rate": 0.001, "batch_size": 32})

    return run.id
```

## Parent-child task relationships

When a parent task calls child tasks, the plugin can share the same W&B run across all of them. This is useful for tracking an entire workflow in a single run.

```python {hl_lines=[1, 9, 16]}
@wandb_init
@env.task
async def child_task(x: int) -> int:
    run = get_wandb_run()
    run.log({"child_metric": x * 2})
    return x * 2

@wandb_init
@env.task
async def parent_task() -> int:
    run = get_wandb_run()
    run.log({"parent_metric": 100})

    # Child task shares the parent's run by default
    result = await child_task(5)

    return result
```

By default (`run_mode="auto"`), child tasks reuse their parent's W&B run. All metrics logged by the parent and children appear in the same run in the W&B UI.

## Run modes

The `run_mode` parameter controls how tasks create or reuse W&B runs. There are three modes:

| Mode             | Behavior                                                                   |
| ---------------- | -------------------------------------------------------------------------- |
| `auto` (default) | Create a new run if no parent run exists, otherwise reuse the parent's run |
| `new`            | Always create a new run, even if a parent run exists                       |
| `shared`         | Always reuse the parent's run (fails if no parent run exists)              |

### Using `run_mode="new"` for independent runs

```python {hl_lines=1}
@wandb_init(run_mode="new")
@env.task
async def independent_child(x: int) -> int:
    run = get_wandb_run()
    # This task gets its own separate run
    run.log({"independent_metric": x})
    return x

@wandb_init
@env.task
async def parent_task() -> str:
    run = get_wandb_run()
    parent_run_id = run.id

    # This child creates its own run
    await independent_child(5)

    # Parent's run is unchanged
    assert run.id == parent_run_id
    return parent_run_id
```

### Using `run_mode="shared"` for explicit sharing

```python {hl_lines=1}
@wandb_init(run_mode="shared")
@env.task
async def must_share_run(x: int) -> int:
    # This task requires a parent run to exist
    # It will fail if called as a top-level task
    run = get_wandb_run()
    run.log({"shared_metric": x})
    return x
```

## Configuration with `wandb_config`

Use `wandb_config()` to configure W&B runs. You can set it at the workflow level or override it for specific tasks, allowing you to provide configuration values at runtime.

### Workflow-level configuration

```python {hl_lines=["5-9"]}
if __name__ == "__main__":
    flyte.init_from_config()

    flyte.with_runcontext(
        custom_context=wandb_config(
            project="my-project",
            entity="my-team",
            tags=["experiment-1", "production"],
            config={"model": "resnet50", "dataset": "imagenet"},
        ),
    ).run(train_task)
```

### Overriding configuration for child tasks

Use `wandb_config()` as a context manager to override settings for specific child task calls:

```python {hl_lines=[8, 12]}
@wandb_init
@env.task
async def parent_task() -> str:
    run = get_wandb_run()
    run.log({"parent_metric": 100})

    # Override tags and config for this child call
    with wandb_config(tags=["special-run"], config={"learning_rate": 0.01}):
        await child_task(10)

    # Override run_mode for this child call
    with wandb_config(run_mode="new"):
        await child_task(20)  # Gets its own run

    return "done"
```

## Using traces with W&B runs

Flyte traces can access the parent task's W&B run without needing the `@wandb_init` decorator. This is useful for helper functions that should log to the same run:

```python {hl_lines=[1, 3]}
@flyte.trace
async def log_validation_metrics(accuracy: float, f1: float):
    run = get_wandb_run()
    if run:
        run.log({"val_accuracy": accuracy, "val_f1": f1})

@wandb_init
@env.task
async def train_and_validate() -> str:
    run = get_wandb_run()

    # Training loop
    for epoch in range(10):
        run.log({"train_loss": 1.0 / (epoch + 1)})

    # Trace logs to the same run
    await log_validation_metrics(accuracy=0.95, f1=0.92)

    return "done"
```

> **📝 Note**
>
> Do not apply `@wandb_init` to traces. Traces automatically access the parent task's run via `get_wandb_run()`.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/distributed_training ===

# Distributed training

When running distributed training jobs, multiple processes run simultaneously across GPUs. The `@wandb_init` decorator automatically detects distributed training environments and coordinates W&B logging across processes.

The plugin:

- Auto-detects distributed context from environment variables (set by launchers like `torchrun`)
- Controls which processes initialize W&B runs based on the `run_mode` and `rank_scope` parameters
- Generates unique run IDs that distinguish between workers and ranks
- Adds links to W&B runs in the Flyte UI

## Quick start

Here's a minimal single-node example that logs metrics from a distributed training task. By default (`run_mode="auto"`, `rank_scope="global"`), only rank 0 logs to W&B:

```python
import flyte
import torch
import torch.distributed
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import get_wandb_run, wandb_config, wandb_init

image = flyte.Image.from_debian_base(name="torch-wandb").with_pip_packages(
    "flyteplugins-wandb", "flyteplugins-pytorch"
)

env = flyte.TaskEnvironment(
    name="distributed_env",
    image=image,
    resources=flyte.Resources(gpu="A100:2"),
    plugin_config=Elastic(nproc_per_node=2, nnodes=1),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)

@wandb_init
@env.task
def train() -> float:
    torch.distributed.init_process_group("nccl")

    # Only rank 0 gets a W&B run object; others get None
    run = get_wandb_run()

    # Simulate training
    for step in range(100):
        loss = 1.0 / (step + 1)

        # Safe to call on all ranks - only rank 0 actually logs
        if run:
            run.log({"loss": loss, "step": step})

    torch.distributed.destroy_process_group()
    return loss

if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(
        custom_context=wandb_config(project="my-project", entity="my-team")
    ).run(train)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/distributed_training_quick_start.py*

A few things to note:

1. Use the `Elastic` plugin to configure distributed training (number of processes, nodes)
2. Apply `@wandb_init` as the outermost decorator
3. Check if `run` is not None before logging - only the primary rank has a run object in `auto` mode

> **📝 Note**
>
> The `if run:` check is always safe regardless of run mode. In `shared` and `new` modes all ranks get a run object, but the check doesn't hurt and keeps your code portable across modes.

![Single-node auto](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/wandb/single_node_auto_flyte.png)

## Run modes in distributed training

The `run_mode` parameter controls how W&B runs are created across distributed processes. The behavior differs between single-node (one machine, multiple GPUs) and multi-node (multiple machines) setups.

### Single-node behavior

| Mode             | Which ranks log       | Result                                 |
| ---------------- | --------------------- | -------------------------------------- |
| `auto` (default) | Only rank 0           | 1 W&B run                              |
| `shared`         | All ranks to same run | 1 W&B run with metrics labeled by rank |
| `new`            | Each rank separately  | N W&B runs (grouped in UI)             |

### Multi-node behavior

For multi-node training, the `rank_scope` parameter controls the granularity of W&B runs:

- **`global`** (default): Treat all workers as one unit
- **`worker`**: Treat each worker/node independently

The combination of `run_mode` and `rank_scope` determines logging behavior:

| `run_mode` | `rank_scope` | Who initializes W&B    | W&B Runs | Grouping |
| ---------- | ------------ | ---------------------- | -------- | -------- |
| `auto`     | `global`     | Global rank 0 only     | 1        | -        |
| `auto`     | `worker`     | Local rank 0 per worker | N        | -        |
| `shared`   | `global`     | All ranks (shared globally) | 1        | -        |
| `shared`   | `worker`     | All ranks (shared per worker) | N        | -        |
| `new`      | `global`     | All ranks              | N × M    | 1 group  |
| `new`      | `worker`     | All ranks              | N × M    | N groups |

Where `N` = number of workers/nodes, `M` = processes per worker.
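The table above can be condensed into a small helper for predicting how many W&B runs a given configuration produces. This is an illustrative sketch only; `expected_wandb_runs` is not part of the plugin API:

```python
def expected_wandb_runs(
    run_mode: str, rank_scope: str, n_workers: int, procs_per_worker: int
) -> int:
    """Predict the number of W&B runs per the table above.

    Illustrative helper, not a plugin function.
    """
    if run_mode == "new":
        # Every rank creates its own run: N x M total.
        return n_workers * procs_per_worker
    # auto and shared produce one run per scope unit.
    return 1 if rank_scope == "global" else n_workers

# 2 nodes x 4 GPUs:
print(expected_wandb_runs("auto", "global", 2, 4))   # 1
print(expected_wandb_runs("auto", "worker", 2, 4))   # 2
print(expected_wandb_runs("new", "global", 2, 4))    # 8
```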

### Choosing run mode and rank scope

- **`auto`** (recommended): Use when you want clean dashboards with minimal runs. Most metrics (loss, accuracy) are the same across ranks after gradient synchronization, so logging from one rank is sufficient.
- **`shared`**: Use when you need to compare metrics across ranks in a single view. Each rank's metrics are labeled with an `x_label` identifier. Useful for debugging load imbalance or per-GPU throughput.
- **`new`**: Use when you need completely separate runs per GPU, for example to track GPU-specific metrics or compare training dynamics across devices.

For multi-node training:
- Use **`rank_scope="global"`** (default) for most cases. A single consolidated run across all nodes is sufficient since metrics like loss and accuracy converge after gradient synchronization.
- Use **`rank_scope="worker"`** for debugging and per-node analysis. This is useful when you need to inspect data distribution across nodes, compare predictions from different workers, or track metrics on individual batches outside the main node.

## Single-node multi-GPU

For single-node distributed training, configure the `Elastic` plugin with `nnodes=1` and set `nproc_per_node` to your GPU count.

### Basic example with `auto` mode

```python {hl_lines=["6-7", 13, 18, 30]}
import os

import torch
import torch.distributed
import flyte
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import wandb_init, get_wandb_run

env = flyte.TaskEnvironment(
    name="single_node_env",
    image=image,
    resources=flyte.Resources(gpu="A100:4"),
    plugin_config=Elastic(nproc_per_node=4, nnodes=1),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)

@wandb_init # run_mode="auto" (default)
@env.task
def train_single_node() -> float:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    run = get_wandb_run()

    # Training loop - only rank 0 logs
    for epoch in range(10):
        loss = train_epoch(model, dataloader, device)

        if run:
            run.log({"epoch": epoch, "loss": loss})

    torch.distributed.destroy_process_group()
    return loss
```

### Using `shared` mode for per-rank metrics

When you need to see metrics from all GPUs in a single run, use `run_mode="shared"`:

```python {hl_lines=[3, 13, 19]}
import os

@wandb_init(run_mode="shared")
@env.task
def train_with_per_gpu_metrics() -> float:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # In shared mode, all ranks get a run object
    run = get_wandb_run()

    for step in range(1000):
        loss, throughput = train_step(model, batch, device)

        # Each rank logs with automatic x_label identification
        if run:
            run.log({
                "loss": loss,
                "throughput_samples_per_sec": throughput,
                "gpu_memory_used": torch.cuda.memory_allocated(device),
            })

    torch.distributed.destroy_process_group()
    return loss
```

![Single-node shared](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/wandb/single_node_shared_flyte.png)

In the W&B UI, metrics from each rank appear with distinct labels, allowing you to compare GPU utilization and throughput across devices.

![Single-node shared W&B UI](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/wandb/single_node_shared_wandb.png)

### Using `new` mode for per-rank runs

When you need completely separate W&B runs for each GPU, use `run_mode="new"`. Each rank gets its own run, and runs are grouped together in the W&B UI:

```python {hl_lines=[1, "11-12"]}
@wandb_init(run_mode="new")  # Each rank gets its own run
@env.task
def train_per_rank() -> float:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    # ...

    # Each rank has its own W&B run
    run = get_wandb_run()

    # Run IDs: {base}-rank-{rank}
    # All runs are grouped under {base} in W&B UI
    run.log({"train/loss": loss.item(), "rank": rank})
    # ...
```

With `run_mode="new"`:

- Each rank creates its own W&B run
- Run IDs follow the pattern `{run_name}-{action_name}-rank-{rank}`
- All runs are grouped together in the W&B UI for comparison

## Multi-node training with `Elastic`

For multi-node distributed training, set `nnodes` to your node count. The `rank_scope` parameter controls whether you get a single W&B run across all nodes (`global`) or one run per node (`worker`).

### Global scope (default): Single run across all nodes

With `run_mode="auto"` and `rank_scope="global"` (both defaults), only global rank 0 initializes W&B, resulting in a single run for the entire distributed job:

```python {hl_lines=["11-12", "27-30", "35", "59-60", "95-98"]}
import os

import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

import flyte
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import wandb_init, wandb_config, get_wandb_run

image = flyte.Image.from_debian_base(name="torch-wandb").with_pip_packages(
    "flyteplugins-wandb", "flyteplugins-pytorch", pre=True
)

multi_node_env = flyte.TaskEnvironment(
    name="multi_node_env",
    image=image,
    resources=flyte.Resources(
        cpu=(1, 2),
        memory=("1Gi", "10Gi"),
        gpu="A100:4",
        shm="auto",
    ),
    plugin_config=Elastic(
        nproc_per_node=4,  # GPUs per node
        nnodes=2,          # Number of nodes
    ),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)

@wandb_init  # rank_scope="global" by default → 1 run total
@multi_node_env.task
def train_multi_node(epochs: int, batch_size: int) -> float:
    torch.distributed.init_process_group("nccl")

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # Model with DDP
    model = MyModel().to(device)
    model = DDP(model, device_ids=[local_rank])

    # Distributed data loading
    dataset = MyDataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # Only global rank 0 gets a W&B run
    run = get_wandb_run()

    for epoch in range(epochs):
        sampler.set_epoch(epoch)
        model.train()

        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if run and batch_idx % 100 == 0:
                run.log({
                    "train/loss": loss.item(),
                    "train/epoch": epoch,
                    "train/batch": batch_idx,
                })

        if run:
            run.log({"train/epoch_complete": epoch})

    # Barrier ensures all ranks finish before cleanup
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()

    return loss.item()

if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(
        custom_context=wandb_config(
            project="multi-node-training",
            tags=["distributed", "multi-node"],
        )
    ).run(train_multi_node, epochs=10, batch_size=32)
```

With this configuration:

- Two nodes run the task, each with 4 GPUs (8 total processes)
- Only global rank 0 creates a W&B run
- Run ID follows the pattern `{run_name}-{action_name}`
- The Flyte UI shows a single link to the W&B run

### Worker scope: One run per node

Use `rank_scope="worker"` when you want each node to have its own W&B run for per-node analysis:

```python {hl_lines=[1, 8]}
@wandb_init(rank_scope="worker")  # 1 run per worker/node
@multi_node_env.task
def train_per_worker(epochs: int, batch_size: int) -> float:
    torch.distributed.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # ...

    # Local rank 0 of each worker gets a W&B run
    run = get_wandb_run()

    if run:
        # Each worker logs to its own run
        run.log({"train/loss": loss.item()})
    # ...
```

With `run_mode="auto"`, `rank_scope="worker"`:

- Each node's local rank 0 creates a W&B run
- Run IDs follow the pattern `{run_name}-{action_name}-worker-{worker_index}`
- The Flyte UI shows links to each worker's W&B run

![Multi-node](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/wandb/multi_node.png)

### Shared mode: All ranks log to the same run

Use `run_mode="shared"` when you need metrics from all ranks in a single view. Each rank's metrics are labeled with an `x_label` identifier.

#### Shared + global scope (1 run total)

```python {hl_lines=[1, 7]}
@wandb_init(run_mode="shared")  # All ranks log to 1 shared run
@multi_node_env.task
def train_shared_global() -> float:
    torch.distributed.init_process_group("nccl")
    # ...

    # All ranks get a run object, all log to the same run
    run = get_wandb_run()

    # Each rank logs with automatic x_label identification
    run.log({"train/loss": loss.item(), "rank": rank})
    # ...
```

#### Shared + worker scope (N runs, 1 per node)

```python {hl_lines=[1, 7, 10]}
@wandb_init(run_mode="shared", rank_scope="worker")  # 1 shared run per worker
@multi_node_env.task
def train_shared_worker() -> float:
    torch.distributed.init_process_group("nccl")
    # ...

    # All ranks get a run object, grouped by worker
    run = get_wandb_run()

    # Ranks on the same worker share a run
    run.log({"train/loss": loss.item(), "local_rank": local_rank})
    # ...
```

### New mode: Separate run per rank

Use `run_mode="new"` when you need completely separate runs per GPU. Runs are grouped in the W&B UI for easy comparison.

#### New + global scope (N×M runs, 1 group)

```python {hl_lines=[1, 7, 10]}
@wandb_init(run_mode="new")  # Each rank gets its own run, all in 1 group
@multi_node_env.task
def train_new_global() -> float:
    torch.distributed.init_process_group("nccl")
    # ...

    # Each rank has its own run
    run = get_wandb_run()

    # Run IDs: {base}-rank-{global_rank}
    run.log({"train/loss": loss.item()})
    # ...
```

#### New + worker scope (N×M runs, N groups)

```python {hl_lines=[1, 7, 10]}
@wandb_init(run_mode="new", rank_scope="worker")  # Each rank gets own run, grouped per worker
@multi_node_env.task
def train_new_worker() -> float:
    torch.distributed.init_process_group("nccl")
    # ...

    # Each rank has its own run, grouped by worker
    run = get_wandb_run()

    # Run IDs: {base}-worker-{idx}-rank-{local_rank}
    run.log({"train/loss": loss.item()})
    # ...
```

## How it works

The plugin automatically detects distributed training by checking environment variables set by distributed launchers like `torchrun`:

| Environment variable | Description                                              |
| -------------------- | -------------------------------------------------------- |
| `RANK`               | Global rank across all processes                         |
| `WORLD_SIZE`         | Total number of processes                                |
| `LOCAL_RANK`         | Rank within the current node                             |
| `LOCAL_WORLD_SIZE`   | Number of processes on the current node                  |
| `GROUP_RANK`         | Node/worker index (0 for first node, 1 for second, etc.) |

When these variables are present, the plugin:

1. **Determines which ranks should initialize W&B** based on `run_mode` and `rank_scope`
2. **Generates unique run IDs** that include worker and rank information
3. **Creates UI links** for each W&B run (single link with `rank_scope="global"`, one per worker with `rank_scope="worker"`)

The plugin automatically adapts to your training setup, eliminating the need for manual distributed configuration.
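To make the detection step concrete, here is a minimal, hypothetical sketch of how the env vars set by `torchrun` can decide whether the current process should initialize W&B. This mirrors the behavior described above but is not the plugin's actual implementation:

```python
import os

def should_init_wandb(run_mode: str = "auto", rank_scope: str = "global") -> bool:
    """Decide whether this process initializes a W&B run.

    Hypothetical sketch based on the env vars torchrun sets;
    the plugin's internal logic may differ.
    """
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if run_mode in ("shared", "new"):
        return True  # every rank participates
    # auto: only one primary rank per scope unit initializes
    if rank_scope == "global":
        return rank == 0
    return local_rank == 0  # one primary rank per worker/node

# Simulate global rank 2 (local rank 2) on the first node:
os.environ.update({"RANK": "2", "LOCAL_RANK": "2", "GROUP_RANK": "0"})
print(should_init_wandb())                   # False (auto/global: only rank 0)
print(should_init_wandb(run_mode="shared"))  # True (all ranks)
```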

### Run ID patterns

| Scenario                     | Run ID Pattern                                | Group                    |
| ---------------------------- | --------------------------------------------- | ------------------------ |
| Single-node auto/shared      | `{base}`                                      | -                        |
| Single-node new              | `{base}-rank-{rank}`                          | `{base}`                 |
| Multi-node auto/shared (global) | `{base}`                                   | -                        |
| Multi-node auto/shared (worker) | `{base}-worker-{idx}`                      | -                        |
| Multi-node new (global)      | `{base}-rank-{global_rank}`                   | `{base}`                 |
| Multi-node new (worker)      | `{base}-worker-{idx}-rank-{local_rank}`       | `{base}-worker-{idx}`    |

Where `{base}` = `{run_name}-{action_name}`
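The patterns in the table can be captured in a small formatting helper. This is an illustrative sketch of the naming scheme only, not a plugin API:

```python
def wandb_run_id(
    base: str,
    run_mode: str,
    rank_scope: str,
    worker: int = 0,
    rank: int = 0,
    multi_node: bool = True,
) -> str:
    """Format a run ID per the table above (illustrative, not plugin API)."""
    if run_mode in ("auto", "shared"):
        if multi_node and rank_scope == "worker":
            return f"{base}-worker-{worker}"
        return base
    # run_mode == "new"
    if multi_node and rank_scope == "worker":
        return f"{base}-worker-{worker}-rank-{rank}"  # rank = local rank
    return f"{base}-rank-{rank}"  # rank = global rank

print(wandb_run_id("myrun-a0", "auto", "global"))                   # myrun-a0
print(wandb_run_id("myrun-a0", "new", "worker", worker=1, rank=3))  # myrun-a0-worker-1-rank-3
```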

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/sweeps ===

# Sweeps

W&B sweeps automate hyperparameter optimization by running multiple trials with different parameter combinations. The `@wandb_sweep` decorator creates a sweep and makes it easy to run trials in parallel using Flyte's distributed execution.

## Creating a sweep

Use `@wandb_sweep` to create a W&B sweep when the task executes:

```python
import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
)

env = flyte.TaskEnvironment(
    name="wandb-example",
    image=flyte.Image.from_debian_base(name="wandb-example").with_pip_packages(
        "flyteplugins-wandb"
    ),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@wandb_init
def objective():
    """Objective function that W&B calls for each trial."""
    wandb_run = wandb.run
    config = wandb_run.config

    # Simulate training with hyperparameters from the sweep
    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})

@wandb_sweep
@env.task
async def run_sweep() -> str:
    sweep_id = get_wandb_sweep_id()

    # Run 10 trials
    wandb.agent(sweep_id, function=objective, count=10)

    return sweep_id

if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64, 128]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(run_sweep)

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/sweep.py*

The `@wandb_sweep` decorator:

- Creates a W&B sweep when the task starts
- Makes the sweep ID available via `get_wandb_sweep_id()`
- Adds a link to the main sweeps page in the Flyte UI

Use `wandb_sweep_config()` to define the sweep parameters. This is passed to W&B's sweep API.

> **📝 Note**
>
> Random and Bayesian searches run indefinitely, and the sweep remains in the `Running` state until you stop it.
> You can stop a running sweep from the Weights & Biases UI or from the command line.

## Running parallel agents

Flyte's distributed execution makes it easy to run multiple sweep agents in parallel, each on its own compute resources:

```python
import asyncio
from datetime import timedelta

import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
    get_wandb_context,
)

env = flyte.TaskEnvironment(
    name="wandb-parallel-sweep-example",
    image=flyte.Image.from_debian_base(
        name="wandb-parallel-sweep-example"
    ).with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@wandb_init
def objective():
    wandb_run = wandb.run
    config = wandb_run.config

    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})

@wandb_sweep
@env.task
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
    """Single agent that runs a subset of trials."""
    wandb.agent(
        sweep_id, function=objective, count=count, project=get_wandb_context().project
    )
    return agent_id

@wandb_sweep
@env.task
async def run_parallel_sweep(total_trials: int = 20, trials_per_agent: int = 5) -> str:
    """Orchestrate multiple agents running in parallel."""
    sweep_id = get_wandb_sweep_id()

    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent

    # Launch agents in parallel, each with its own resources
    agent_tasks = [
        sweep_agent.override(
            resources=flyte.Resources(cpu="2", memory="4Gi"),
            retries=3,
            timeout=timedelta(minutes=30),
        )(agent_id=i, sweep_id=sweep_id, count=trials_per_agent)
        for i in range(num_agents)
    ]

    await asyncio.gather(*agent_tasks)
    return sweep_id

if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(
        run_parallel_sweep,
        total_trials=20,
        trials_per_agent=5,
    )

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/parallel_sweep.py*

This pattern provides:

- **Distributed execution**: Each agent runs on separate compute nodes
- **Resource allocation**: Specify CPU, memory, and GPU per agent
- **Fault tolerance**: Failed agents can retry without affecting others
- **Timeout protection**: Prevent runaway trials
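The agent count in the example comes from ceiling division of the trial budget by trials per agent. A tiny sketch of the same arithmetic (the `plan_agents` helper is for illustration only; the example itself passes the same `count` to every agent, so the last agent may run a few extra trials unless you trim as shown):

```python
def plan_agents(total_trials: int, trials_per_agent: int) -> list[int]:
    """Split a trial budget across agents using ceiling division,
    mirroring the num_agents computation in the example above."""
    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent
    counts = [trials_per_agent] * num_agents
    # Trim the last agent so the total matches the budget exactly.
    counts[-1] -= num_agents * trials_per_agent - total_trials
    return counts

print(plan_agents(20, 5))  # [5, 5, 5, 5]
print(plan_agents(22, 5))  # [5, 5, 5, 5, 2]
```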

> **📝 Note**
>
> `run_parallel_sweep` links to the main Weights & Biases sweeps page because the sweep ID cannot be determined at link rendering time, while `sweep_agent` links to the specific sweep URL since it receives the sweep ID as an input.

![Sweep](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/wandb/sweep.png)

## Writing objective functions

The objective function is called by `wandb.agent()` for each trial. It must be a regular Python function decorated with `@wandb_init`:

```python {hl_lines=["1-2", "5-6"]}
@wandb_init
def objective():
    """Objective function for sweep trials."""
    # Access hyperparameters from wandb.run.config
    run = wandb.run
    config = run.config

    # Your training code
    model = create_model(
        learning_rate=config.learning_rate,
        hidden_size=config.hidden_size,
    )

    for epoch in range(config.epochs):
        train_loss = train_epoch(model)
        val_loss = validate(model)

        # Log metrics - W&B tracks these for the sweep
        run.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
        })

    # The final val_loss is used by the sweep to rank trials
```

Key points:

- Use `@wandb_init` on the objective function (not `@env.task`)
- Access hyperparameters via `wandb.run.config` (not `get_wandb_run()`, since the objective runs outside the Flyte task context)
- Log the metric specified in `wandb_sweep_config(metric=...)` so the sweep can optimize it
- The function is called multiple times by `wandb.agent()`, once per trial

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/downloading_logs ===

# Downloading logs

This integration enables downloading Weights & Biases run data, including metrics history, summary data, and synced files.

## Automatic download

Set `download_logs=True` to automatically download run data after your task completes:

```python {hl_lines=1}
@wandb_init(download_logs=True)
@env.task
async def train_with_download():
    run = get_wandb_run()

    for epoch in range(10):
        run.log({"loss": 1.0 / (epoch + 1)})

    return run.id
```

The downloaded data is traced by Flyte and appears as a `Dir` output in the Flyte UI. Downloaded files include:

- `summary.json`: Final summary metrics
- `metrics_history.json`: Step-by-step metrics history
- Any files synced by W&B (`requirements.txt`, `wandb_metadata.json`, etc.)
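
Once the `Dir` output is available locally, these files are plain JSON and can be inspected with the standard library. The helper below is only a sketch: it assumes `metrics_history.json` is a list of per-step metric dicts, which may differ from the exact on-disk format.

```python
import json
from pathlib import Path

def summarize_run_dir(run_dir: str) -> dict:
    """Read the downloaded W&B files and extract a few headline values."""
    run_path = Path(run_dir)

    # Final summary metrics written at the end of the run
    summary = json.loads((run_path / "summary.json").read_text())

    # Assumed format: a list of per-step metric dicts
    history = json.loads((run_path / "metrics_history.json").read_text())

    return {
        "final_metrics": summary,
        "num_logged_steps": len(history),
    }
```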

You can also set `download_logs=True` in `wandb_config()`:

```python {hl_lines=5}
flyte.with_runcontext(
    custom_context=wandb_config(
        project="my-project",
        entity="my-team",
        download_logs=True,
    ),
).run(train_task)
```

![Logs](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/wandb/logs.png)

For sweeps, set `download_logs=True` on `@wandb_sweep` or `wandb_sweep_config()` to download all trial data:

```python {hl_lines=1}
@wandb_sweep(download_logs=True)
@env.task
async def run_sweep():
    sweep_id = get_wandb_sweep_id()
    wandb.agent(sweep_id, function=objective, count=10)
    return sweep_id
```

![Sweep Logs](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/wandb/sweep_logs.png)

## Accessing run directories during execution

Use `get_wandb_run_dir()` to access the local W&B run directory during task execution. This is useful for writing custom files that get synced to W&B:

```python {hl_lines=[1, 7, "18-19"]}
from flyteplugins.wandb import get_wandb_run_dir

@wandb_init
@env.task
def train_with_artifacts():
    run = get_wandb_run()
    local_dir = get_wandb_run_dir()

    # Train your model
    for epoch in range(10):
        run.log({"loss": 1.0 / (epoch + 1)})

    # Save model checkpoint to the run directory
    model_path = f"{local_dir}/model_checkpoint.pt"
    torch.save(model.state_dict(), model_path)

    # Save custom metrics file
    with open(f"{local_dir}/custom_metrics.json", "w") as f:
        json.dump({"final_accuracy": 0.95}, f)

    return run.id
```

Files written to the run directory are automatically synced to W&B and can be accessed later via the W&B UI or by setting `download_logs=True`.

> **📝 Note**
>
> `get_wandb_run_dir()` accesses the local directory without making network calls. Files written here may have a brief delay before appearing in the W&B cloud.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/constraints_and_best_practices ===

# Constraints and best practices

## Decorator ordering

`@wandb_init` and `@wandb_sweep` must be the **outermost decorators**, applied after `@env.task`:

```python
# Correct
@wandb_init
@env.task
async def my_task():
    ...

# Incorrect - will not work
@env.task
@wandb_init
async def my_task():
    ...
```

## Traces cannot use decorators

Do not apply `@wandb_init` to traces. Traces automatically access the parent task's run via `get_wandb_run()`:

```python
# Correct
@flyte.trace
async def my_trace():
    run = get_wandb_run()
    if run:
        run.log({"metric": 42})

# Incorrect - don't decorate traces
@wandb_init
@flyte.trace
async def my_trace():
    ...
```

## Maximum sweep agents

[W&B limits sweeps to a maximum of 20 concurrent agents](https://docs.wandb.ai/models/sweeps/existing-project#3-launch-agents).

## Configuration priority

Configuration is merged with the following priority (highest to lowest):

1. Decorator parameters (`@wandb_init(project="...")`)
2. Context manager (`with wandb_config(...)`)
3. Workflow-level context (`flyte.with_runcontext(custom_context=wandb_config(...))`)
4. Auto-generated values (run ID from Flyte context)
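
The merge behavior can be sketched with plain dictionaries. This is illustrative only; `merge_wandb_config` is a hypothetical helper, not part of the plugin's API.

```python
def merge_wandb_config(
    decorator: dict,
    context_manager: dict,
    run_context: dict,
    auto_generated: dict,
) -> dict:
    """Illustrative merge: walk from lowest to highest priority so that
    higher-priority layers overwrite lower-priority ones."""
    merged: dict = {}
    for layer in (auto_generated, run_context, context_manager, decorator):
        merged.update({k: v for k, v in layer.items() if v is not None})
    return merged

config = merge_wandb_config(
    decorator={"project": "decorator-project"},
    context_manager={"entity": "my-team"},
    run_context={"project": "ctx-project", "tags": ["baseline"]},
    auto_generated={"id": "run-name-action-name"},
)
# "project" is taken from the decorator, the highest-priority source
```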

## Run ID generation

When no explicit `id` is provided, the plugin generates run IDs using the pattern:

```
{run_name}-{action_name}
```

This ensures unique, predictable IDs that can be matched between the `Wandb` link class and manual `wandb.init()` calls.
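
As a sketch, the default ID is simply the two context values joined with a hyphen. The `default_wandb_run_id` helper here is hypothetical; inside a task, the values would come from `flyte.ctx()`.

```python
def default_wandb_run_id(run_name: str, action_name: str) -> str:
    """Reproduce the plugin's default run ID pattern: {run_name}-{action_name}."""
    return f"{run_name}-{action_name}"

# Inside a task, you would feed it values from the Flyte context:
#   ctx = flyte.ctx()
#   run_id = default_wandb_run_id(ctx.action.run_name, ctx.action.name)
```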

## Sync delay for local files

Files written to the run directory (via `get_wandb_run_dir()`) are synced to W&B asynchronously. There may be a brief delay before they appear in the W&B cloud or can be downloaded via `download_wandb_run_dir()`.

## Shared run mode requirements

When using `run_mode="shared"`, the task requires a parent task to have already created a W&B run. Calling a task with `run_mode="shared"` as a top-level task will fail.

## Objective functions for sweeps

Objective functions passed to `wandb.agent()` should:

- Be regular Python functions (not Flyte tasks)
- Be decorated with `@wandb_init`
- Access hyperparameters via `wandb.run.config` (not `get_wandb_run()`)
- Log the metric specified in `wandb_sweep_config(metric=...)` so the sweep can optimize it

## Error handling

The plugin raises standard exceptions:

- `RuntimeError`: When `download_wandb_run_dir()` is called without a run ID and no active run exists
- `wandb.errors.AuthenticationError`: When `WANDB_API_KEY` is not set or invalid
- `wandb.errors.CommError`: When a run cannot be found in the W&B cloud

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/manual ===

# Manual integration

If you need more control over W&B initialization, you can use the `Wandb` and `WandbSweep` link classes directly instead of the decorators. This lets you call `wandb.init()` and `wandb.finish()` yourself while still getting automatic links in the Flyte UI.

## Using the Wandb link class

Add a `Wandb` link to your task to generate a link to the W&B run in the Flyte UI:

```python
import flyte
import wandb
from flyteplugins.wandb import Wandb

env = flyte.TaskEnvironment(
    name="wandb-manual-init-example",
    image=flyte.Image.from_debian_base(
        name="wandb-manual-init-example"
    ).with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@env.task(
    links=(
        Wandb(
            project="my-project",
            entity="my-team",
            run_mode="new",
            # No id parameter - link will auto-generate from run_name-action_name
        ),
    )
)
async def train_model(learning_rate: float) -> str:
    ctx = flyte.ctx()

    # Generate run ID matching the link's auto-generated ID
    run_id = f"{ctx.action.run_name}-{ctx.action.name}"

    # Manually initialize W&B
    wandb_run = wandb.init(
        project="my-project",
        entity="my-team",
        id=run_id,
        config={"learning_rate": learning_rate},
    )

    # Your training code
    for epoch in range(10):
        loss = 1.0 / (learning_rate * (epoch + 1))
        wandb_run.log({"epoch": epoch, "loss": loss})

    # Manually finish the run
    wandb_run.finish()

    return wandb_run.id

if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext().run(
        train_model,
        learning_rate=0.01,
    )

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/init_manual.py*

### With a custom run ID

If you want to use your own run ID, specify it in both the link and the `wandb.init()` call:

```python {hl_lines=[6, 14]}
@env.task(
    links=(
        Wandb(
            project="my-project",
            entity="my-team",
            id="my-custom-run-id",
        ),
    )
)
async def train_with_custom_id() -> str:
    run = wandb.init(
        project="my-project",
        entity="my-team",
        id="my-custom-run-id",  # Must match the link's ID
        resume="allow",
    )

    # Training code...
    run.finish()
    return run.id
```

### Adding links at runtime with override

You can also add links when calling a task using `.override()`:

```python {hl_lines=9}
@env.task
async def train_model(learning_rate: float) -> str:
    # ... training code with manual wandb.init() ...
    return run.id

# Add link when running the task
result = await train_model.override(
    links=(Wandb(project="my-project", entity="my-team", run_mode="new"),)
)(learning_rate=0.01)
```

## Using the `WandbSweep` link class

Use `WandbSweep` to add a link to a W&B sweep:

```python
import flyte
import wandb
from flyteplugins.wandb import WandbSweep

env = flyte.TaskEnvironment(
    name="wandb-manual-sweep-example",
    image=flyte.Image.from_debian_base(
        name="wandb-manual-sweep-example"
    ).with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

def objective():
    with wandb.init(project="my-project", entity="my-team") as wandb_run:
        config = wandb_run.config

        for epoch in range(config.epochs):
            loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
            wandb_run.log({"epoch": epoch, "loss": loss})

@env.task(
    links=(
        WandbSweep(
            project="my-project",
            entity="my-team",
        ),
    )
)
async def manual_sweep() -> str:
    # Manually create the sweep
    sweep_config = {
        "method": "random",
        "metric": {"name": "loss", "goal": "minimize"},
        "parameters": {
            "learning_rate": {"min": 0.0001, "max": 0.1},
            "batch_size": {"values": [16, 32, 64]},
            "epochs": {"value": 10},
        },
    }

    sweep_id = wandb.sweep(sweep_config, project="my-project", entity="my-team")

    # Run the sweep
    wandb.agent(sweep_id, function=objective, count=10, project="my-project")

    return sweep_id

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.with_runcontext().run(manual_sweep)

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/sweep_manual.py*

The link will point to the project's sweeps page. If you have the sweep ID, you can specify it in the link:

```python {hl_lines=6}
@env.task(
    links=(
        WandbSweep(
            project="my-project",
            entity="my-team",
            id="known-sweep-id",
        ),
    )
)
async def resume_sweep() -> str:
    # Resume an existing sweep
    wandb.agent("known-sweep-id", function=objective, count=10)
    return "known-sweep-id"
```

=== PAGE: https://www.union.ai/docs/v2/union/integrations/codegen ===

# Code generation

The code generation plugin turns natural-language prompts into tested, production-ready Python code.

You describe what the code should do, along with sample data, schema definitions, constraints, and typed inputs/outputs, and the plugin handles the rest: generating code, writing tests, building an isolated [code sandbox](/docs/v2/union/user-guide/sandboxing/code-sandboxing) with the right dependencies, running the tests, diagnosing failures, and iterating until everything passes. The result is a validated script you can execute against real data or deploy as a reusable Flyte task.

## Installation

```bash
pip install flyteplugins-codegen

# For Agent mode (Claude-only)
pip install flyteplugins-codegen[agent]
```

## Quick start

```python{hl_lines=[3, 4, 6, 12, 14, "20-25"]}
import flyte
from flyte.io import File
from flyte.sandbox import sandbox_environment
from flyteplugins.codegen import AutoCoderAgent

agent = AutoCoderAgent(model="gpt-4.1", name="summarize-sales")

env = flyte.TaskEnvironment(
    name="my-env",
    secrets=[flyte.Secret(key="openai_key", as_env_var="OPENAI_API_KEY")],
    image=flyte.Image.from_debian_base().with_pip_packages(
        "flyteplugins-codegen",
    ),
    depends_on=[sandbox_environment],
)

@env.task
async def process_data(csv_file: File) -> tuple[float, int, int]:
    result = await agent.generate.aio(
        prompt="Read the CSV and compute total_revenue, total_units and row_count.",
        samples={"sales": csv_file},
        outputs={"total_revenue": float, "total_units": int, "row_count": int},
    )
    return await result.run.aio()
```

The `depends_on=[sandbox_environment]` declaration is required. It ensures the sandbox runtime is available when dynamically-created sandboxes execute.

![Sandbox](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/codegen/sandbox.png)

## Two execution backends

The plugin supports two backends for generating and validating code. Both share the same `AutoCoderAgent` interface and produce the same `CodeGenEvalResult`.

### LiteLLM (default)

Uses structured-output LLM calls to generate code, detect packages, build sandbox images, run tests, diagnose failures, and iterate. Works with any model that supports structured outputs (GPT-4, Claude, Gemini, etc. via LiteLLM).

```python{hl_lines=[1, 3]}
agent = AutoCoderAgent(
    name="my-task",
    model="gpt-4.1",
    max_iterations=10,
)
```

The LiteLLM backend follows a fixed pipeline:

```mermaid
flowchart TD
    A["prompt + samples"] --> B["generate_plan"]
    B --> C["generate_code"]
    C --> D["detect_packages"]
    D --> E["build_image"]
    E --> F{skip_tests?}
    F -- yes --> G["return result"]
    F -- no --> H["generate_tests"]
    H --> I["execute_tests"]
    I --> J{pass?}
    J -- yes --> G
    J -- no --> K["diagnose_error"]
    K --> L{error type?}
    L -- "logic error" --> M["regenerate code"]
    L -- "environment error" --> N["add packages, rebuild image"]
    L -- "test error" --> O["fix test expectations"]
    M --> I
    N --> I
    O --> I
```

The loop continues until tests pass or `max_iterations` is reached.

![LiteLLM](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/codegen/litellm.png)

### Agent (Claude)

Uses the Claude Agent SDK to autonomously generate, test, and fix code. The agent has access to `Bash`, `Read`, `Write`, and `Edit` tools and decides what to do at each step. Test execution commands (`pytest`) are intercepted and run inside isolated sandboxes.

```python{hl_lines=["3-4"]}
agent = AutoCoderAgent(
    name="my-task",
    model="claude-sonnet-4-5-20250929",
    backend="claude",
)
```

> [!NOTE]
> Agent mode requires `ANTHROPIC_API_KEY` as a Flyte secret and is Claude-only.

**Key differences from LiteLLM:**

|                       | LiteLLM                           | Agent                                          |
| --------------------- | --------------------------------- | ---------------------------------------------- |
| **Execution**         | Fixed generate-test-fix pipeline  | Autonomous agent decides actions               |
| **Model support**     | Any model with structured outputs | Claude only                                    |
| **Iteration control** | `max_iterations`                  | `agent_max_turns`                              |
| **Test execution**    | Direct sandbox execution          | `pytest` commands intercepted via hooks        |
| **Tool safety**       | N/A                               | Commands classified as safe/denied/intercepted |
| **Observability**     | Logs + token counts               | Full tool call tracing in Flyte UI             |

In Agent mode, Bash commands are classified before execution:

- **Safe** (`ls`, `cat`, `grep`, `head`, etc.) — allowed to run directly
- **Intercepted** (`pytest`) — routed to sandbox execution
- **Denied** (`apt`, `pip install`, `curl`, etc.) — blocked for safety
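
A first-token classifier along these lines captures the idea. The command sets and the default-deny fallback below are assumptions for illustration, not the plugin's actual lists or logic.

```python
SAFE_COMMANDS = {"ls", "cat", "grep", "head", "tail", "wc"}
INTERCEPTED_COMMANDS = {"pytest"}
DENIED_COMMANDS = {"apt", "pip", "curl", "wget"}

def classify_bash_command(command: str) -> str:
    """Classify a Bash command by its first token (illustrative only)."""
    executable = command.strip().split()[0] if command.strip() else ""
    if executable in INTERCEPTED_COMMANDS:
        return "intercepted"   # routed to sandbox execution
    if executable in DENIED_COMMANDS:
        return "denied"        # blocked for safety
    if executable in SAFE_COMMANDS:
        return "safe"          # allowed to run directly
    return "denied"            # assume default-deny for unknown commands
```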

## Providing data

### Sample data

Pass sample data via `samples` as `File` objects or pandas `DataFrame`s. The plugin automatically:

1. Converts DataFrames to CSV files
2. Infers [Pandera](https://pandera.readthedocs.io/) schemas from the data — column types, nullability
3. Parses natural-language `constraints` into Pandera checks (e.g., `"quantity must be positive"` becomes `pa.Check.gt(0)`)
4. Extracts data context — column statistics, distributions, patterns, sample rows
5. Injects all of this into the LLM prompt so the generated code is aware of the exact data structure

Pandera is used purely for prompt enrichment, not runtime validation. The generated code does not import Pandera — it benefits from the LLM knowing the precise data structure. The generated schemas are stored on `result.generated_schemas` for inspection.

```python{hl_lines=[3]}
result = await agent.generate.aio(
    prompt="Clean and validate the data, remove duplicates",
    samples={"orders": orders_df, "products": products_file},
    constraints=["quantity must be positive", "price between 0 and 10000"],
    outputs={"cleaned_orders": File},
)
```

### Schema and constraints

Use `schema` to provide free-form context about data formats or target structures (e.g., a database schema). Use `constraints` to declare business rules that the generated code must respect:

```python{hl_lines=["4-17"]}
result = await agent.generate.aio(
    prompt=prompt,
    samples={"readings": sensor_df},
    schema="""Output JSON schema for report_json:
    {
        "sensor_id": str,
        "avg_temp": float,
        "min_temp": float,
        "max_temp": float,
        "avg_humidity": float,
    }
    """,
    constraints=[
        "Temperature values must be between -40 and 60 Celsius",
        "Humidity values must be between 0 and 100 percent",
        "Output report must have one row per unique sensor_id",
    ],
    outputs={
        "report_json": str,
        "total_anomalies": int,
    },
)
```

![Pandera Constraints](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/codegen/pandera_constraints.png)

### Inputs and outputs

Declare `inputs` for non-sample arguments (e.g., thresholds, flags) and `outputs` for the expected result types.

Supported output types: `str`, `int`, `float`, `bool`, `datetime.datetime`, `datetime.timedelta`, `File`.

Sample entries are automatically added as `File` inputs — you do not need to redeclare them.

```python{hl_lines=[4, 5]}
result = await agent.generate.aio(
    prompt="Filter transactions above the threshold",
    samples={"transactions": tx_file},
    inputs={"threshold": float, "include_pending": bool},
    outputs={"filtered": File, "count": int},
)
```

## Running generated code

`agent.generate()` returns a `CodeGenEvalResult`. If `result.success` is `True`, the generated code passed all tests and you can execute it against real data. If `max_iterations` (LiteLLM) or `agent_max_turns` (Agent) is reached without tests passing, `result.success` is `False` and `result.error` contains the failure details.

Both `run()` and `as_task()` return output values as a tuple in the order declared in `outputs`. If there is a single output, the value is returned directly (not wrapped in a tuple).
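
The unwrapping convention can be sketched as follows. This is an illustrative helper, not part of the plugin.

```python
def unpack_outputs(values: tuple, output_names: list[str]):
    """Mimic the return convention: a single output is returned directly,
    multiple outputs come back as a tuple in declaration order."""
    if len(output_names) == 1:
        return values[0]
    return values
```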

### One-shot execution with `result.run()`

Runs the generated code in a sandbox. If samples were provided during `generate()`, they are used as default inputs.

```python
# Use sample data as defaults
total_revenue, total_units, count = await result.run.aio()

# Override specific inputs
total_revenue, total_units, count = await result.run.aio(threshold=0.5)

# Sync version
total_revenue, total_units, count = result.run()
```

`result.run()` accepts optional configuration:

```python{hl_lines=["4-6"]}
total_revenue, total_units, count = await result.run.aio(
    name="execute-on-data",
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    retries=2,
    timeout=600,
    cache="auto",
)
```

### Reusable task with `result.as_task()`

Creates a callable sandbox task from the generated code. Useful when you want to run the same generated code against different data.

```python{hl_lines=[1, "6-7", "9-10"]}
task = result.as_task(
    name="run-sensor-analysis",
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

# Call with sample defaults
report, total_anomalies = await task.aio()

# Call with different data
report, total_anomalies = await task.aio(readings=new_data_file)
```

## Error diagnosis

The LiteLLM backend classifies test failures into three categories and applies targeted fixes:

| Error type    | Meaning                       | Action                                           |
| ------------- | ----------------------------- | ------------------------------------------------ |
| `logic`       | Bug in the generated code     | Regenerate code with specific patch instructions |
| `environment` | Missing package or dependency | Add the package and rebuild the sandbox image    |
| `test_error`  | Bug in the generated test     | Fix the test expectations                        |

If the same error persists after a fix, the plugin reclassifies it (e.g., `logic` to `test_error`) to try the other approach.
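
The reclassification step can be sketched like this. The helper and its flip table are illustrative assumptions, not the plugin's actual code.

```python
def reclassify(error_type: str, same_error_repeated: bool) -> str:
    """If a fix for one diagnosis did not help, try the complementary one."""
    flips = {"logic": "test_error", "test_error": "logic"}
    if same_error_repeated and error_type in flips:
        return flips[error_type]
    return error_type
```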

In Agent mode, the agent diagnoses and fixes issues autonomously based on error output.

## Durable execution

Code generation is expensive — it involves multiple LLM calls, image builds, and sandbox executions. Without durability, a transient failure in the pipeline (network blip, OOM, downstream service error) would force the entire process to restart from scratch: regenerating code, rebuilding images, re-running sandboxes, making additional LLM calls.

Flyte solves this through two complementary mechanisms: **replay logs** and **caching**.

### Replay logs

Flyte maintains a replay log that records every trace and task execution within a run. When a task crashes and retries, the system replays the log from the previous attempt rather than recomputing everything:

- No additional model calls
- No code regeneration
- No sandbox re-execution
- No container rebuilds

The workflow breezes through the earlier steps and resumes from the failure point. This applies as long as the traces and tasks execute in the same order and use the same inputs as the first attempt.

### Caching

Separately, Flyte can cache task results across runs. With `cache="auto"`, sandbox executions (image builds, test runs, code execution) are cached. This is useful when you re-run the same pipeline — not just when recovering from a crash, but across entirely separate invocations with the same inputs.

Together, replay logs handle crash recovery within a run, and caching avoids redundant work across runs.

### Non-determinism in Agent mode

One challenge with agents is that they are inherently non-deterministic — the sequence of actions can vary between runs, which could break replay.

In practice, the codegen agent follows a predictable pattern (write code, generate tests, run tests, inspect results), which works in replay's favor. The plugin also embeds logic that instructs the agent not to regenerate or re-execute steps that already completed successfully in the first run. This acts as an additional safety check alongside the replay log to account for non-determinism.

![Agent](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/codegen/agent.png)

On the first attempt, the full pipeline runs. If a transient failure occurs, the system instantly replays the traces (which track model calls) and sandbox executions, allowing the pipeline to resume from the point of failure.

![Durability](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/codegen/durability.png)

## Observability

### LiteLLM backend

- Logs every iteration with attempt count, error type, and package changes
- Tracks total input/output tokens across all LLM calls (available on `result.total_input_tokens` and `result.total_output_tokens`)
- Results include full conversation history for debugging (`result.conversation_history`)

### Agent backend

- Traces each tool call (name + input) via `PostToolUse` hooks
- Traces tool failures via `PostToolUseFailure` hooks
- Traces a summary when the agent finishes (total tool calls, tool distribution, final image/packages)
- Classifies Bash commands as safe, denied, or intercepted (for sandbox execution)
- All traces appear in the Flyte UI

## Examples

### Processing CSVs with different schemas

Generate code that handles varying CSV formats, then run on real data:

```python{hl_lines=[1, 3, 14, 16, 27]}
from flyteplugins.codegen import AutoCoderAgent

agent = AutoCoderAgent(
    name="sales-processor",
    model="gpt-4.1",
    max_iterations=5,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    litellm_params={"temperature": 0.2, "max_tokens": 4096},
)

@env.task
async def process_sales(csv_file: File) -> dict[str, float | int]:
    result = await agent.generate.aio(
        prompt="Read the CSV and compute total_revenue, total_units, and transaction_count.",
        samples={"csv_data": csv_file},
        outputs={
            "total_revenue": float,
            "total_units": int,
            "transaction_count": int,
        },
    )

    if not result.success:
        raise RuntimeError(f"Code generation failed: {result.error}")

    total_revenue, total_units, transaction_count = await result.run.aio()

    return {
        "total_revenue": total_revenue,
        "total_units": total_units,
        "transaction_count": transaction_count,
    }
```

### DataFrame analysis with constraints

Pass DataFrames directly and enforce business rules with constraints:

```python{hl_lines=[10, "15-19"]}
agent = AutoCoderAgent(
    model="gpt-4.1",
    name="sensor-analysis",
    base_packages=["numpy"],
    max_sample_rows=30,
)

@env.task
async def analyze_sensors(sensor_df: pd.DataFrame) -> tuple[File, int]:
    result = await agent.generate.aio(
        prompt="""Analyze IoT sensor data. For each sensor, calculate mean/min/max
temperature, mean humidity, and count warnings. Output a summary CSV.""",
        samples={"readings": sensor_df},
        constraints=[
            "Temperature values must be between -40 and 60 Celsius",
            "Humidity values must be between 0 and 100 percent",
            "Output report must have one row per unique sensor_id",
        ],
        outputs={
            "report": File,
            "total_anomalies": int,
        },
    )

    if not result.success:
        raise RuntimeError(f"Code generation failed: {result.error}")

    task = result.as_task(
        name="run-sensor-analysis",
        resources=flyte.Resources(cpu=1, memory="512Mi"),
    )

    return await task.aio(readings=result.original_samples["readings"])
```

### Agent mode

The same task using Claude as an autonomous agent:

```python{hl_lines=[3]}
agent = AutoCoderAgent(
    name="sales-agent",
    backend="claude",
    model="claude-sonnet-4-5-20250929",
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

@env.task
async def process_sales_with_agent(csv_file: File) -> dict[str, float | int]:
    result = await agent.generate.aio(
        prompt="Read the CSV and compute total_revenue, total_units, and transaction_count.",
        samples={"csv_data": csv_file},
        outputs={
            "total_revenue": float,
            "total_units": int,
            "transaction_count": int,
        },
    )

    if not result.success:
        raise RuntimeError(f"Agent code generation failed: {result.error}")

    total_revenue, total_units, transaction_count = await result.run.aio()

    return {
        "total_revenue": total_revenue,
        "total_units": total_units,
        "transaction_count": transaction_count,
    }
```

## Configuration

### LiteLLM parameters

Tune model behavior with `litellm_params`:

```python{hl_lines=["5-8"]}
agent = AutoCoderAgent(
    name="my-task",
    model="anthropic/claude-sonnet-4-20250514",
    api_key="ANTHROPIC_API_KEY",
    litellm_params={
        "temperature": 0.3,
        "max_tokens": 4000,
    },
)
```

### Image configuration

Control the registry and Python version for sandbox images:

```python{hl_lines=["6-10"]}
from flyte.sandbox import ImageConfig

agent = AutoCoderAgent(
    name="my-task",
    model="gpt-4.1",
    image_config=ImageConfig(
        registry="my-registry.io",
        registry_secret="registry-creds",
        python_version=(3, 12),
    ),
)
```

### Skipping tests

Set `skip_tests=True` to skip test generation and execution. The agent still generates code, detects packages, and builds the sandbox image, but does not generate or run tests.

```python{hl_lines=[4]}
agent = AutoCoderAgent(
    name="my-task",
    model="gpt-4.1",
    skip_tests=True,
)
```

> [!NOTE]
> `skip_tests` only applies to LiteLLM mode. In Agent mode, the agent autonomously decides when to test.

### Base packages

Ensure specific packages are always installed in every sandbox:

```python{hl_lines=[4]}
agent = AutoCoderAgent(
    name="my-task",
    model="gpt-4.1",
    base_packages=["numpy", "pandas"],
)
```

## Best practices

- **One agent per task.** Each `generate()` call builds its own sandbox image and manages its own package state. Running multiple agents in the same task can cause resource contention and makes failures harder to diagnose.
- **Keep `cache="auto"` (the default).** Caching flows to all internal sandboxes, making retries near-instant. Use `"disable"` during development if you want fresh executions, or `"override"` to force re-execution and update the cached result.
- **Set `max_iterations` conservatively.** Start with 5-10 iterations. If the model cannot produce correct code in that budget, the prompt or constraints likely need refinement.
- **Provide constraints for data-heavy tasks.** Explicit constraints (e.g., `"quantity must be positive"`) produce better schemas and better generated code.
- **Inspect `result.generated_schemas`.** Review the inferred Pandera schemas to verify the model understood your data structure correctly.

## API reference

### `AutoCoderAgent` constructor

| Parameter         | Type              | Default        | Description                                                                            |
| ----------------- | ----------------- | -------------- | -------------------------------------------------------------------------------------- |
| `name`            | `str`             | `"auto-coder"` | Unique name for tracking and image naming                                              |
| `model`           | `str`             | `"gpt-4.1"`    | LiteLLM model identifier                                                               |
| `backend`         | `str`             | `"litellm"`    | Execution backend: `"litellm"` or `"claude"`                                           |
| `system_prompt`   | `str`             | `None`         | Custom system prompt override                                                          |
| `api_key`         | `str`             | `None`         | Name of the environment variable containing the LLM API key (e.g., `"OPENAI_API_KEY"`) |
| `api_base`        | `str`             | `None`         | Custom API base URL                                                                    |
| `litellm_params`  | `dict`            | `None`         | Extra LiteLLM params (temperature, max_tokens, etc.)                                   |
| `base_packages`   | `list[str]`       | `None`         | Always-install pip packages                                                            |
| `resources`       | `flyte.Resources` | `None`         | Resources for sandbox execution (default: 1 CPU, 1Gi)                                  |
| `image_config`    | `ImageConfig`     | `None`         | Registry, secret, and Python version                                                   |
| `max_iterations`  | `int`             | `10`           | Max generate-test-fix iterations (LiteLLM mode)                                        |
| `max_sample_rows` | `int`             | `100`          | Rows to sample from data for LLM context                                               |
| `skip_tests`      | `bool`            | `False`        | Skip test generation and execution (LiteLLM mode)                                      |
| `sandbox_retries` | `int`             | `0`            | Flyte task-level retries for each sandbox execution                                    |
| `timeout`         | `int`             | `None`         | Timeout in seconds for sandboxes                                                       |
| `env_vars`        | `dict[str, str]`  | `None`         | Environment variables for sandboxes                                                    |
| `secrets`         | `list[Secret]`    | `None`         | Flyte secrets for sandboxes                                                            |
| `cache`           | `str`             | `"auto"`       | Cache behavior: `"auto"`, `"override"`, or `"disable"`                                 |
| `agent_max_turns` | `int`             | `50`           | Max turns when `backend="claude"`                                                      |

### `generate()` parameters

| Parameter     | Type                           | Default  | Description                                                                             |
| ------------- | ------------------------------ | -------- | --------------------------------------------------------------------------------------- |
| `prompt`      | `str`                          | required | Natural-language task description                                                       |
| `schema`      | `str`                          | `None`   | Free-form context about data formats or target structures                               |
| `constraints` | `list[str]`                    | `None`   | Natural-language constraints (e.g., `"quantity must be positive"`)                      |
| `samples`     | `dict[str, File \| DataFrame]` | `None`   | Sample data. DataFrames are auto-converted to CSV files.                                |
| `inputs`      | `dict[str, type]`              | `None`   | Non-sample input types (e.g., `{"threshold": float}`)                                   |
| `outputs`     | `dict[str, type]`              | `None`   | Output types. Supported: `str`, `int`, `float`, `bool`, `datetime`, `timedelta`, `File` |

### `CodeGenEvalResult` fields

| Field                      | Type                      | Description                                               |
| -------------------------- | ------------------------- | --------------------------------------------------------- |
| `success`                  | `bool`                    | Whether tests passed                                      |
| `solution`                 | `CodeSolution`            | Generated code (`.code`, `.language`, `.system_packages`) |
| `tests`                    | `str`                     | Generated test code                                       |
| `output`                   | `str`                     | Test output                                               |
| `exit_code`                | `int`                     | Test exit code                                            |
| `error`                    | `str \| None`             | Error message if failed                                   |
| `attempts`                 | `int`                     | Number of iterations used                                 |
| `image`                    | `str`                     | Built sandbox image with all dependencies                 |
| `detected_packages`        | `list[str]`               | Pip packages detected                                     |
| `detected_system_packages` | `list[str]`               | Apt packages detected                                     |
| `generated_schemas`        | `dict[str, str] \| None`  | Pandera schemas as Python code strings                    |
| `data_context`             | `str \| None`             | Extracted data context                                    |
| `original_samples`         | `dict[str, File] \| None` | Sample data as Files (defaults for `run()`/`as_task()`)   |
| `total_input_tokens`       | `int`                     | Total input tokens across all LLM calls                   |
| `total_output_tokens`      | `int`                     | Total output tokens across all LLM calls                  |
| `conversation_history`     | `list[dict]`              | Full LLM conversation history for debugging               |

### `CodeGenEvalResult` methods

| Method                              | Description                                                        |
| ----------------------------------- | ------------------------------------------------------------------ |
| `result.run(**overrides)`           | Execute generated code in a sandbox. Sample data used as defaults. |
| `await result.run.aio(**overrides)` | Async version of `run()`.                                          |
| `result.as_task(name, ...)`         | Create a reusable callable sandbox task from the generated code.   |

Both `run()` and `as_task()` accept optional `name`, `resources`, `retries`, `timeout`, `env_vars`, `secrets`, and `cache` parameters.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/mlflow ===

# MLflow

The MLflow plugin integrates [MLflow](https://mlflow.org/) experiment tracking with Flyte. It provides a `@mlflow_run` decorator that automatically manages MLflow runs within Flyte tasks, with support for autologging, parent-child run sharing, distributed training, and auto-generated UI links.

The decorator works with both sync and async tasks.

## Installation

```bash
pip install flyteplugins-mlflow
```

Requires `mlflow` and `flyte`.

## Quick start

```python{hl_lines=[3, 9, "13-16", 22]}
import flyte
import mlflow
from flyteplugins.mlflow import mlflow_run, get_mlflow_run

env = flyte.TaskEnvironment(
    name="mlflow-tracking",
    resources=flyte.Resources(cpu=1, memory="500Mi"),
    image=flyte.Image.from_debian_base(name="mlflow_example").with_pip_packages(
        "flyteplugins-mlflow"
    ),
)

@mlflow_run(
    tracking_uri="http://localhost:5000",
    experiment_name="my-experiment",
)
@env.task
async def train_model(learning_rate: float) -> str:
    mlflow.log_param("lr", learning_rate)
    mlflow.log_metric("loss", 0.42)

    run = get_mlflow_run()
    return run.info.run_id
```

![Link](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/link.png)

![MLflow UI](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/mlflow_dashboard.png)

> [!NOTE]
> `@mlflow_run` must be the outermost decorator, before `@env.task`:
>
> ```python{hl_lines=["1-2"]}
> @mlflow_run          # outermost
> @env.task            # innermost
> async def my_task(): ...
> ```

## Autologging

Enable MLflow's autologging to automatically capture parameters, metrics, and models without manual `mlflow.log_*` calls.

### Generic autologging

```python{hl_lines=[1]}
@mlflow_run(autolog=True)
@env.task
async def train():
    from sklearn.linear_model import LogisticRegression

    model = LogisticRegression()
    model.fit(X, y)  # Parameters, metrics, and model are logged automatically
```

### Framework-specific autologging

Pass `framework` to use a framework-specific autolog implementation:

```python{hl_lines=[3]}
@mlflow_run(
    autolog=True,
    framework="sklearn",
    log_models=True,
    log_datasets=False,
)
@env.task
async def train_sklearn():
    from sklearn.ensemble import RandomForestClassifier

    model = RandomForestClassifier(n_estimators=100)
    model.fit(X_train, y_train)
```

Any framework that provides an `mlflow.{framework}.autolog()` function is supported. You can find the full list of supported frameworks [here](https://mlflow.org/docs/latest/ml/tracking/autolog/#supported-libraries).

You can pass additional autolog parameters via `autolog_kwargs`:

```python{hl_lines=[4]}
@mlflow_run(
    autolog=True,
    framework="pytorch",
    autolog_kwargs={"log_every_n_epoch": 5},
)
@env.task
async def train_pytorch():
    ...
```

![Autolog](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/autolog.png)

## Run modes

The `run_mode` parameter controls how MLflow runs are created and shared across tasks:

| Mode               | Behavior                                                              |
| ------------------ | --------------------------------------------------------------------- |
| `"auto"` (default) | Reuse the parent's run if one exists, otherwise create a new run      |
| `"new"`            | Always create a new independent run                                   |
| `"nested"`         | Create a new run nested under the parent via `mlflow.parentRunId` tag |

### Sharing a run across tasks

With `run_mode="auto"` (the default), child tasks reuse the parent's MLflow run:

```python{hl_lines=[1, 5, 7]}
@mlflow_run
@env.task
async def parent_task():
    mlflow.log_param("stage", "parent")
    await child_task()  # Shares the same MLflow run

@mlflow_run
@env.task
async def child_task():
    mlflow.log_metric("child_metric", 1.0)  # Logged to the parent's run
```

### Creating independent runs

Use `run_mode="new"` when a task should always create its own top-level MLflow run, completely independent of any parent:

```python{hl_lines=[1]}
@mlflow_run(run_mode="new")
@env.task
async def standalone_experiment():
    mlflow.log_param("experiment_type", "baseline")
    mlflow.log_metric("accuracy", 0.95)
```

### Nested runs

Use `run_mode="nested"` to create a child run that appears under the parent in the MLflow UI. This works across processes and containers via the `mlflow.parentRunId` tag.

![Nested runs](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/mlflow_hpo.png)

This is the recommended pattern for hyperparameter optimization, where each trial should be tracked as a child of the parent study run:

```python{hl_lines=[1, 2, 15, "22-25"]}
from flyteplugins.mlflow import Mlflow

@mlflow_run(run_mode="nested")
@env.task(links=[Mlflow()])
async def run_trial(trial_number: int, n_estimators: int, max_depth: int) -> float:
    """Each trial creates a nested MLflow run under the parent."""
    mlflow.log_params({"n_estimators": n_estimators, "max_depth": max_depth})
    mlflow.log_param("trial_number", trial_number)

    model = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth)
    model.fit(X_train, y_train)

    rmse = float(np.sqrt(mean_squared_error(y_val, model.predict(X_val))))
    mlflow.log_metric("rmse", rmse)
    return rmse

@mlflow_run
@env.task
async def hpo_search(n_trials: int = 30) -> str:
    """Parent run tracks the overall study."""
    run = get_mlflow_run()
    mlflow.log_param("n_trials", n_trials)

    # Run trials in parallel — each gets a nested MLflow run
    rmses = await asyncio.gather(
        *(run_trial(trial_number=i, **params) for i, params in enumerate(trial_params))
    )

    mlflow.log_metric("best_rmse", min(rmses))
    return run.info.run_id
```

![HPO](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/hpo.png)

## Workflow-level configuration

Use `mlflow_config()` with `flyte.with_runcontext()` to set MLflow configuration for an entire workflow. All `@mlflow_run`-decorated tasks in the workflow inherit these settings:

```python{hl_lines=[1, "4-8"]}
from flyteplugins.mlflow import mlflow_config

r = flyte.with_runcontext(
    custom_context=mlflow_config(
        tracking_uri="http://localhost:5000",
        experiment_id="846992856162999",
        tags={"team": "ml"},
    )
).run(train_model, learning_rate=0.001)
```

This eliminates the need to repeat `tracking_uri` and experiment settings on every `@mlflow_run` decorator.

### Per-task overrides

Use `mlflow_config()` as a context manager inside a task to override configuration for specific child tasks:

```python{hl_lines=[6]}
@mlflow_run
@env.task
async def parent_task():
    await shared_child()  # Inherits parent config

    with mlflow_config(run_mode="new", tags={"role": "independent"}):
        await independent_child()  # Gets its own run
```

### Configuration priority

Settings are resolved in priority order:

1. Explicit `@mlflow_run` decorator arguments
2. `mlflow_config()` context configuration
3. Environment variables (for `tracking_uri`)
4. MLflow defaults

## Distributed training

In distributed training, only rank 0 logs to MLflow by default. The plugin detects rank automatically from the `RANK` environment variable:

```python{hl_lines=[1, "4-6"]}
@mlflow_run
@env.task
async def distributed_train():
    # Only rank 0 creates an MLflow run and logs metrics.
    # Other ranks execute the task function directly without
    # creating an MLflow run or incurring any MLflow overhead.
    ...
```

On non-rank-0 workers, no MLflow run is created and `get_mlflow_run()` returns `None`. The task function still executes normally — only the MLflow instrumentation is skipped.

![Distributed training](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/distributed_training.png)

You can also set rank explicitly:

```python{hl_lines=[1]}
@mlflow_run(rank=0)
@env.task
async def train():
    ...
```

## MLflow UI links

The `Mlflow` link class displays links to the MLflow UI in the Flyte UI.

Since the MLflow run is created inside the task at execution time, the run URL cannot be determined before the task starts. Links are only shown when a run URL is already available from context, either because a parent task created the run, or because an explicit URL is provided.

The recommended pattern is for the parent task to create the MLflow run; child tasks that inherit the run (via `run_mode="auto"`) then display a link to it. For nested runs (`run_mode="nested"`), children display a link to the parent run.

### Setup

Set `link_host` via `mlflow_config()` and attach `Mlflow()` links to child tasks:

```python{hl_lines=[4, 17]}
from flyteplugins.mlflow import Mlflow, mlflow_config

@mlflow_run
@env.task(links=[Mlflow()])
async def child_task():
    ...  # Link points to the parent's MLflow run

@mlflow_run
@env.task
async def parent_task():
    await child_task()

if __name__ == "__main__":
    r = flyte.with_runcontext(
        custom_context=mlflow_config(
            tracking_uri="http://localhost:5000",
            link_host="http://localhost:5000",
        )
    ).run(parent_task)
```

> [!NOTE]
> `Mlflow()` is instantiated without a `link` argument because the URL is auto-generated at runtime. When the parent task creates an MLflow run, the plugin builds the URL from `link_host` and the run's experiment/run IDs, then propagates it to child tasks via the Flyte context. Passing an explicit `link` would bypass this auto-generation.

### Custom URL templates

The default link format is:

```
{host}/#/experiments/{experiment_id}/runs/{run_id}
```

For platforms like Databricks that use a different URL structure, provide a custom template:

```python{hl_lines=[3]}
mlflow_config(
    link_host="https://dbc-xxx.cloud.databricks.com",
    link_template="{host}/ml/experiments/{experiment_id}/runs/{run_id}",
)
```
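
Link expansion is plain placeholder substitution, which can be sketched with `str.format`. The function name and signature here are hypothetical; the two templates come from this page.

```python
DEFAULT_TEMPLATE = "{host}/#/experiments/{experiment_id}/runs/{run_id}"

def build_link(host: str, experiment_id: str, run_id: str,
               template: str = DEFAULT_TEMPLATE) -> str:
    """Fill the documented placeholders into the URL template."""
    return template.format(host=host, experiment_id=experiment_id, run_id=run_id)
```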

### Explicit links

If you know the run URL ahead of time, you can set it directly:

```python{hl_lines=[1]}
@env.task(links=[Mlflow(link="https://mlflow.example.com/#/experiments/1/runs/abc123")])
async def my_task():
    ...
```

### Link behavior by run mode

| Run mode   | Link behavior                                                                                  |
| ---------- | ---------------------------------------------------------------------------------------------- |
| `"auto"`   | Parent link propagates to child tasks sharing the run                                          |
| `"new"`    | Parent link is cleared; no link is shown until the task's own run is available to its children |
| `"nested"` | Parent link is kept and renamed to "MLflow (parent)"                                           |

## Automatic Flyte tags

When running inside Flyte, the plugin automatically tags MLflow runs with execution metadata:

| Tag                 | Description      |
| ------------------- | ---------------- |
| `flyte.action_name` | Task action name |
| `flyte.run_name`    | Flyte run name   |
| `flyte.project`     | Flyte project    |
| `flyte.domain`      | Flyte domain     |

These tags are merged with any user-provided tags.

## API reference

### `mlflow_run` and `mlflow_config`

`mlflow_run` is a decorator that manages MLflow runs for Flyte tasks. `mlflow_config` creates workflow-level configuration or per-task overrides. Both accept the same core parameters:

| Parameter         | Type             | Default  | Description                                                                   |
| ----------------- | ---------------- | -------- | ----------------------------------------------------------------------------- |
| `run_mode`        | `str`            | `"auto"` | `"auto"`, `"new"`, or `"nested"`                                              |
| `tracking_uri`    | `str`            | `None`   | MLflow tracking server URL                                                    |
| `experiment_name` | `str`            | `None`   | MLflow experiment name (raises `ValueError` if combined with `experiment_id`) |
| `experiment_id`   | `str`            | `None`   | MLflow experiment ID (raises `ValueError` if combined with `experiment_name`) |
| `run_name`        | `str`            | `None`   | Human-readable run name (raises `ValueError` if combined with `run_id`)       |
| `run_id`          | `str`            | `None`   | Explicit MLflow run ID (raises `ValueError` if combined with `run_name`)      |
| `tags`            | `dict[str, str]` | `None`   | Tags for the run                                                              |
| `autolog`         | `bool`           | `False`  | Enable MLflow autologging                                                     |
| `framework`       | `str`            | `None`   | Framework for autolog (e.g. `"sklearn"`, `"pytorch"`)                         |
| `log_models`      | `bool`           | `None`   | Log models automatically (requires `autolog`)                                 |
| `log_datasets`    | `bool`           | `None`   | Log datasets automatically (requires `autolog`)                               |
| `autolog_kwargs`  | `dict`           | `None`   | Extra parameters for `mlflow.autolog()`                                       |

Additional keyword arguments are passed to `mlflow.start_run()`.

`mlflow_run` also accepts:

| Parameter | Type  | Default | Description                                              |
| --------- | ----- | ------- | -------------------------------------------------------- |
| `rank`    | `int` | `None`  | Process rank for distributed training (only rank 0 logs) |

`mlflow_config` also accepts:

| Parameter       | Type  | Default | Description                                                                 |
| --------------- | ----- | ------- | --------------------------------------------------------------------------- |
| `link_host`     | `str` | `None`  | MLflow UI host for auto-generating links                                    |
| `link_template` | `str` | `None`  | Custom URL template (placeholders: `{host}`, `{experiment_id}`, `{run_id}`) |

### `get_mlflow_run`

Returns the current `mlflow.ActiveRun` if within a `@mlflow_run`-decorated task. Returns `None` otherwise.

```python
from flyteplugins.mlflow import get_mlflow_run

run = get_mlflow_run()
if run:
    print(run.info.run_id)
```

### `get_mlflow_context`

Returns the current `mlflow_config` settings from the Flyte context, or `None` if no MLflow configuration is set. Useful for inspecting the inherited configuration inside a task:

```python
from flyteplugins.mlflow import get_mlflow_context

@mlflow_run
@env.task
async def my_task():
    config = get_mlflow_context()
    if config:
        print(config.tracking_uri, config.experiment_id)
```

### `Mlflow`

Link class for displaying MLflow UI links in the Flyte console.

| Field  | Type  | Default    | Description                             |
| ------ | ----- | ---------- | --------------------------------------- |
| `name` | `str` | `"MLflow"` | Display name for the link               |
| `link` | `str` | `""`       | Explicit URL (bypasses auto-generation) |

