# Sweeps

W&B sweeps automate hyperparameter optimization by running multiple trials with different parameter combinations. The `@wandb_sweep` decorator creates a sweep and makes it easy to run trials in parallel using Flyte's distributed execution.

## Creating a sweep

Use `@wandb_sweep` to create a W&B sweep when the task executes:

```python
import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
)

env = flyte.TaskEnvironment(
    name="wandb-example",
    image=flyte.Image.from_debian_base(name="wandb-example").with_pip_packages(
        "flyteplugins-wandb"
    ),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@wandb_init
def objective():
    """Objective function that W&B calls for each trial."""
    wandb_run = wandb.run
    config = wandb_run.config

    # Simulate training with hyperparameters from the sweep
    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})

@wandb_sweep
@env.task
async def run_sweep() -> str:
    sweep_id = get_wandb_sweep_id()

    # Run 10 trials
    wandb.agent(sweep_id, function=objective, count=10)

    return sweep_id

if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64, 128]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(run_sweep)

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/sweep.py*

The `@wandb_sweep` decorator:

- Creates a W&B sweep when the task starts
- Makes the sweep ID available via `get_wandb_sweep_id()`
- Adds a link to the main sweeps page in the Flyte UI

Use `wandb_sweep_config()` to define the sweep's search method, target metric, and parameters. The resulting configuration is passed to W&B's sweep API.
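
The keyword arguments mirror the keys of a standard W&B sweep configuration. As a rough local illustration of what a `random` search does with the configuration from the example above, the sketch below draws one trial from the same search space. `sample_trial` is a hypothetical helper, not part of W&B or the plugin, and it assumes uniform sampling over `min`/`max` ranges:

```python
import random

# Plain-dict form of the sweep configuration used in the example above.
sweep_config = {
    "method": "random",
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"min": 0.0001, "max": 0.1},
        "batch_size": {"values": [16, 32, 64, 128]},
        "epochs": {"values": [5, 10, 20]},
    },
}


def sample_trial(parameters: dict, rng: random.Random) -> dict:
    """Draw one hypothetical trial config (illustration only, not W&B code)."""
    trial = {}
    for name, spec in parameters.items():
        if "values" in spec:
            # Discrete parameter: pick one of the listed values.
            trial[name] = rng.choice(spec["values"])
        else:
            # Continuous range: assumed uniform between min and max.
            trial[name] = rng.uniform(spec["min"], spec["max"])
    return trial


rng = random.Random(0)  # seeded so repeated runs draw the same trial
trial = sample_trial(sweep_config["parameters"], rng)
print(trial)
```

Each call to the objective function sees one such combination in `wandb.run.config`.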

> **📝 Note**
>
> Random and Bayesian searches run indefinitely, and the sweep remains in the `Running` state until you stop it.
> You can stop a running sweep from the Weights & Biases UI or from the command line.

## Running parallel agents

Flyte's distributed execution makes it easy to run multiple sweep agents in parallel, each on its own compute resources:

```python
import asyncio
from datetime import timedelta

import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
    get_wandb_context,
)

env = flyte.TaskEnvironment(
    name="wandb-parallel-sweep-example",
    image=flyte.Image.from_debian_base(
        name="wandb-parallel-sweep-example"
    ).with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@wandb_init
def objective():
    wandb_run = wandb.run
    config = wandb_run.config

    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})

@wandb_sweep
@env.task
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
    """Single agent that runs a subset of trials."""
    wandb.agent(
        sweep_id, function=objective, count=count, project=get_wandb_context().project
    )
    return agent_id

@wandb_sweep
@env.task
async def run_parallel_sweep(total_trials: int = 20, trials_per_agent: int = 5) -> str:
    """Orchestrate multiple agents running in parallel."""
    sweep_id = get_wandb_sweep_id()

    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent

    # Launch agents in parallel, each with its own resources
    agent_tasks = [
        sweep_agent.override(
            resources=flyte.Resources(cpu="2", memory="4Gi"),
            retries=3,
            timeout=timedelta(minutes=30),
        )(agent_id=i, sweep_id=sweep_id, count=trials_per_agent)
        for i in range(num_agents)
    ]

    await asyncio.gather(*agent_tasks)
    return sweep_id

if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(
        run_parallel_sweep,
        total_trials=20,
        trials_per_agent=5,
    )

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/parallel_sweep.py*

This pattern provides:

- **Distributed execution**: Each agent runs as its own task on separate compute resources
- **Resource allocation**: Specify CPU, memory, and GPU per agent
- **Fault tolerance**: Failed agents can retry without affecting others
- **Timeout protection**: A per-agent timeout bounds how long each agent can run, which prevents runaway trials
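
The example sizes the agent pool with a ceiling division and gives every agent the same `count`. When `total_trials` is not a multiple of `trials_per_agent`, the agents together run more trials than requested (for example, 22 trials at 5 per agent yields 5 agents and 25 trials). If the exact total matters, one option is to compute a per-agent count up front. This is a minimal sketch; `split_trials` is a hypothetical helper, not part of the plugin:

```python
def split_trials(total_trials: int, trials_per_agent: int) -> list[int]:
    """Return per-agent trial counts that sum exactly to total_trials."""
    if total_trials < 0 or trials_per_agent <= 0:
        raise ValueError("total_trials must be >= 0 and trials_per_agent > 0")

    full_agents, remainder = divmod(total_trials, trials_per_agent)
    counts = [trials_per_agent] * full_agents
    if remainder:
        # The last agent picks up the leftover trials.
        counts.append(remainder)
    return counts


counts = split_trials(22, 5)
print(counts)  # [5, 5, 5, 5, 2]
```

Each entry could then be passed as the `count` argument of one `sweep_agent` call in place of the fixed `trials_per_agent`.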

> **📝 Note**
>
> `run_parallel_sweep` links to the main Weights & Biases sweeps page because the sweep ID cannot be determined at link rendering time. `sweep_agent` links to the specific sweep URL.
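
To make the difference between the two link targets concrete, here is a minimal sketch of both URL shapes. It assumes the default `https://wandb.ai` host and its usual `entity/project/sweeps` layout. The helper names are hypothetical, and the plugin may build its links differently:

```python
def sweeps_page_url(entity: str, project: str, host: str = "https://wandb.ai") -> str:
    """Main sweeps page: needs only the entity and project."""
    return f"{host}/{entity}/{project}/sweeps"


def sweep_url(
    entity: str, project: str, sweep_id: str, host: str = "https://wandb.ai"
) -> str:
    """Page for one sweep: additionally needs the sweep ID."""
    return f"{sweeps_page_url(entity, project, host)}/{sweep_id}"


# "abc123" is a made-up sweep ID for illustration.
print(sweeps_page_url("my-team", "my-project"))
print(sweep_url("my-team", "my-project", "abc123"))
```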

![Sweep](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/wandb/sweep.png)

## Writing objective functions

The objective function is called by `wandb.agent()` for each trial. It must be a regular Python function decorated with `@wandb_init`:

```python
import wandb
from flyteplugins.wandb import wandb_init


@wandb_init
def objective():
    """Objective function for sweep trials."""
    # Access hyperparameters from wandb.run.config
    run = wandb.run
    config = run.config

    # Your training code: create_model, train_epoch, and validate are
    # placeholders for your own functions
    model = create_model(
        learning_rate=config.learning_rate,
        hidden_size=config.hidden_size,
    )

    for epoch in range(config.epochs):
        train_loss = train_epoch(model)
        val_loss = validate(model)

        # Log metrics - W&B tracks these for the sweep
        run.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
        })

    # The last logged val_loss is what the sweep uses to rank trials,
    # provided the sweep config's metric name is "val_loss"
```

Key points:

- Use `@wandb_init` on the objective function (not `@env.task`)
- Access hyperparameters via `wandb.run.config` (not `get_wandb_run()`, because the objective function runs outside the Flyte task context)
- Log the metric specified in `wandb_sweep_config(metric=...)` so the sweep can optimize it
- The function is called multiple times by `wandb.agent()`, once per trial
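
The ranking idea behind the sweep metric can be sketched locally. The example below is an illustration only, not W&B's implementation. It assumes the default behavior in which a run's summary holds the last value logged for the metric, and `rank_trials` and the trial histories are made up for the example:

```python
def rank_trials(
    histories: dict[str, list[float]], goal: str = "minimize"
) -> list[tuple[str, float]]:
    """Order trials by the last logged value of the sweep metric."""
    if goal not in ("minimize", "maximize"):
        raise ValueError("goal must be 'minimize' or 'maximize'")

    # Keep only the final value each trial logged; skip trials that logged nothing.
    finals = {name: values[-1] for name, values in histories.items() if values}
    return sorted(
        finals.items(), key=lambda item: item[1], reverse=(goal == "maximize")
    )


# Hypothetical per-epoch val_loss values for three trials.
histories = {
    "trial-a": [0.9, 0.6, 0.4],
    "trial-b": [0.8, 0.3, 0.5],
    "trial-c": [1.2, 0.7, 0.2],
}

ranking = rank_trials(histories, goal="minimize")
print(ranking)  # [('trial-c', 0.2), ('trial-a', 0.4), ('trial-b', 0.5)]
```

Note that `trial-b` reached 0.3 mid-run but ranks last under this assumption, because only its final value counts. If the best value should count instead, log it explicitly under the metric name at the end of the trial.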

---
**Source**: https://github.com/unionai/unionai-docs/blob/main/content/integrations/wandb/sweeps.md
**HTML**: https://www.union.ai/docs/v2/union/integrations/wandb/sweeps/
