=== PAGE: https://www.union.ai/docs/v2/byoc ===
# Documentation
Welcome to the documentation.
## Subpages
- **Union.ai BYOC**
- **Tutorials**
- **Integrations**
- **Reference**
- **Community**
- **Release Notes**
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide ===
# Union.ai BYOC
Union.ai empowers AI development teams to rapidly ship high-quality code to production by offering optimized performance, unparalleled resource efficiency, and a delightful workflow authoring experience. With Union.ai your team can:
* Run complex AI workloads with performance, scale, and efficiency.
* Achieve millisecond-level execution times with reusable containers.
* Scale out to multiple regions, clusters, and clouds as needed for resource availability, scale, or compliance.
> [!NOTE]
> These are the Union.ai **2.0 beta** docs.
> To switch to [version 1.0](/docs/v1/byoc/) or to another product variant, use the selectors above.
>
> Union.ai is built on top of the leading open-source workflow orchestrator, [Flyte](/docs/v2/flyte/).
>
> Union.ai BYOC (Bring Your Own Cloud) provides all the features of Flyte, plus much more
> in an environment where you keep your data and workflow code on your infrastructure, while Union.ai takes care of the management.
### **Flyte 2**
Flyte 2 represents a fundamental shift in how AI workflows are written and executed. Learn
more in this section.
### **Getting started**
Install Flyte 2, configure your local IDE, create and run your first task, and inspect the results in 2 minutes.
## Subpages
- **Flyte 2**
- **Getting started**
- **Configure tasks**
- **Build tasks**
- **Run and deploy tasks**
- **Authenticating with Union**
- **Considerations**
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/flyte-2 ===
# Flyte 2
Flyte 2 and Union 2 represent a fundamental shift in how workflows are written and executed in Union.
> **Ready to get started?**
>
> Go to the **Getting started** guide to install Flyte 2 and run your first task.
## Pure Python execution
Write workflows in pure Python, enabling a more natural development experience and removing the constraints of a
domain-specific language (DSL).
### Sync Python
```python
import flyte

env = flyte.TaskEnvironment("sync_example_env")

@env.task
def hello_world(name: str) -> str:
    return f"Hello, {name}!"

@env.task
def main(name: str) -> str:
    for i in range(10):
        hello_world(name)
    return "Done"

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, name="World")
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/sync_example.py)
### Async Python
```python
import asyncio

import flyte

env = flyte.TaskEnvironment("async_example_env")

@env.task
async def hello_world(name: str) -> str:
    return f"Hello, {name}!"

@env.task
async def main(name: str) -> str:
    results = []
    for i in range(10):
        results.append(hello_world(name))
    await asyncio.gather(*results)
    return "Done"

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, name="World")
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async_example.py)
As you can see in the hello world example, workflows can be constructed at runtime, allowing for more flexible and
adaptive behavior. Flyte 2 also supports:
- Python's asynchronous programming model to express parallelism.
- Python's native error handling with `try-except`, for example to catch a failure and retry with overridden configurations, like resource requests.
- Predefined static workflows when compile-time safety is critical (coming soon).
## Simplified API
The new API is more intuitive, with fewer abstractions to learn and a focus on simplicity.
| Use case | Flyte 1 | Flyte 2 |
| ----------------------------- | --------------------------- | --------------------------------------- |
| Environment management | `N/A` | `TaskEnvironment` |
| Perform basic computation | `@task` | `@env.task` |
| Combine tasks into a workflow | `@workflow` | `@env.task` |
| Create dynamic workflows | `@dynamic` | `@env.task` |
| Fanout parallelism | `flytekit.map` | Python `for` loop with `asyncio.gather` |
| Conditional execution | `flytekit.conditional` | Python `if-elif-else` |
| Catching workflow failures | `@workflow(on_failure=...)` | Python `try-except` |
There is no `@workflow` decorator. Instead, "workflows" are authored through a pattern of tasks calling tasks.
Tasks are defined within environments, which encapsulate the context and resources needed for execution.
## Fine-grained reproducibility and recoverability
Flyte tasks support caching via `@env.task(cache=...)`, and tracing with `@flyte.trace` extends task-level caching
further, enabling reproducibility and recovery at the level of individual functions within a task.
```python
import flyte

env = flyte.TaskEnvironment(name="trace_example_env")

@flyte.trace
async def call_llm(prompt: str) -> str:
    return "Initial response from LLM"

@env.task
async def finalize_output(output: str) -> str:
    return "Finalized output"

@env.task(cache=flyte.Cache(behavior="auto"))
async def main(prompt: str) -> str:
    output = await call_llm(prompt)
    output = await finalize_output(output)
    return output

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, prompt="Prompt to LLM")
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/trace.py)
Here, `call_llm` runs in the same container as `main` and serves as an automated checkpoint, with full
observability in the UI. If the task run fails, the workflow can recover and replay from where it left off.
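To make the checkpoint-and-replay idea concrete, here is a plain-Python sketch that does not use Flyte. The `checkpointed` decorator and the in-memory `CHECKPOINTS` dict are hypothetical stand-ins for the durable record that tracing keeps; the sketch only illustrates why a traced call is not re-executed when a failed run is retried, not how Flyte implements it.

```python
import asyncio
import functools

# Hypothetical in-memory checkpoint store; a real system would persist this durably.
CHECKPOINTS: dict[tuple, str] = {}
CALLS = {"call_llm": 0}

def checkpointed(fn):
    """Hypothetical decorator: record each result keyed by function name and arguments."""
    @functools.wraps(fn)
    async def wrapper(*args):
        key = (fn.__name__, args)
        if key in CHECKPOINTS:
            # Replay: return the recorded result without running the function again.
            return CHECKPOINTS[key]
        result = await fn(*args)
        CHECKPOINTS[key] = result
        return result
    return wrapper

@checkpointed
async def call_llm(prompt: str) -> str:
    CALLS["call_llm"] += 1  # counts real executions
    return "Initial response from LLM"

async def main(prompt: str, fail: bool) -> str:
    output = await call_llm(prompt)
    if fail:
        # Simulate a failure after the checkpointed step has completed.
        raise RuntimeError("downstream step failed")
    return output.upper()

async def run_with_retry(prompt: str) -> str:
    try:
        return await main(prompt, fail=True)
    except RuntimeError:
        # The retry replays call_llm from its checkpoint instead of re-running it.
        return await main(prompt, fail=False)

if __name__ == "__main__":
    print(asyncio.run(run_with_retry("Prompt to LLM")))
    print(CALLS["call_llm"])  # 1: the expensive call ran only once
```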
## Improved remote functionality
Flyte 2 provides full management of the workflow lifecycle through a standardized API, available through both the CLI and the Python SDK.
| Use case | CLI | Python SDK |
| ------------- | ------------------ | ------------------- |
| Run a task | `flyte run ...` | `flyte.run(...)` |
| Deploy a task | `flyte deploy ...` | `flyte.deploy(...)` |
You can also fetch and run remote (previously deployed) tasks within the course of a running workflow.
```python
import flyte
from flyte import remote

env_1 = flyte.TaskEnvironment(name="env_1")
env_2 = flyte.TaskEnvironment(name="env_2")
env_1.add_dependency(env_2)

@env_2.task
async def remote_task(x: str) -> str:
    return "Remote task processed: " + x

@env_1.task
async def main() -> str:
    remote_task_ref = remote.Task.get("env_2.remote_task", auto_version="latest")
    r = await remote_task_ref(x="Hello")
    return "main called remote and received: " + r

if __name__ == "__main__":
    flyte.init_from_config()
    d = flyte.deploy(env_1)
    print(d[0].summary_repr())
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/remote.py)
## Native Notebook support
Author and run workflows and fetch workflow metadata (I/O and logs) directly from Jupyter notebooks.

## High performance engine
Schedule tasks in milliseconds with reusable containers, which massively increases the throughput of containerized tasks.
```python
from datetime import timedelta

import flyte

# Currently required to enable reusable containers
reusable_image = flyte.Image.from_debian_base().with_pip_packages("unionai-reuse>=0.1.3")

env = flyte.TaskEnvironment(
    name="reusable-env",
    resources=flyte.Resources(memory="1Gi", cpu="500m"),
    reusable=flyte.ReusePolicy(
        replicas=2,  # Create 2 container instances
        concurrency=1,  # Process 1 task per container at a time
        scaledown_ttl=timedelta(minutes=10),  # Individual containers shut down after 10 minutes of inactivity
        idle_ttl=timedelta(hours=1),  # Entire environment shuts down after 1 hour of no tasks
    ),
    image=reusable_image,  # Use the container image augmented with the unionai-reuse library
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/reusable-containers/reuse.py)
Coupled with multi-cluster, multi-cloud, and multi-region support, Flyte 2 can scale to handle even the most demanding
workflows.
## Enhanced UI
New UI with a streamlined and user-friendly experience for authoring and managing workflows.

This UI improves the visualization of workflow execution and monitoring, simplifying access to logs, metadata, and other important information.
## Subpages
- **Flyte 2 > Pure Python**
- **Flyte 2 > Asynchronous model**
- **Flyte 2 > Migration from Flyte 1 to Flyte 2**
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/flyte-2/pure-python ===
# Pure Python
Flyte 2 introduces a new way of writing workflows that is based on pure Python, removing the constraints of a domain-specific language (DSL) and enabling full use of Python's capabilities.
## From `@workflow` DSL to pure Python
| Flyte 1 | Flyte 2 |
| --- | --- |
| `@workflow`-decorated functions are constrained to a subset of Python for defining a static directed acyclic graph (DAG) of tasks. | **No more `@workflow` decorator**: Everything is a `@env.task`, so your top-level "workflow" is simply a task that calls other tasks. |
| `@task`-decorated functions could leverage the full power of Python, but only within individual container executions. | `@env.task`s can call other `@env.task`s and be used to construct workflows with dynamic structures using loops, conditionals, try/except, and any Python construct anywhere. |
| Workflows were compiled into static DAGs at registration time, with tasks as the nodes and the DSL defining the structure. | Workflows are simply tasks that call other tasks. Compile-time safety will be available in the future as `compiled_task`. |
### Flyte 1
```python
import flytekit

image = flytekit.ImageSpec(
    name="hello-world-image",
    packages=["requests"],
)

@flytekit.task(container_image=image)
def mean(data: list[float]) -> float:
    return sum(data) / len(data)

@flytekit.workflow
def main(data: list[float]) -> float:
    output = mean(data)
    # Not allowed: performing trivial operations in a workflow
    # output = output / 100
    # Not allowed: if/else in a workflow
    # if output < 0:
    #     raise ValueError("Output cannot be negative")
    return output
```
### Flyte 2
```python
import flyte

env = flyte.TaskEnvironment(
    "hello_world",
    image=flyte.Image.from_debian_base().with_pip_packages("requests"),
)

@env.task
def mean(data: list[float]) -> float:
    return sum(data) / len(data)

@env.task
def main(data: list[float]) -> float:
    output = mean(data)
    # Allowed: performing trivial operations in a workflow
    output = output / 100
    # Allowed: if/else
    if output < 0:
        raise ValueError("Output cannot be negative")
    return output
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/pure-python/flyte_2.py)
These fundamental changes bring several transformative benefits:
- **Flexibility**: Harness the complete Python language for workflow definition, including all control flow constructs previously forbidden in workflows.
- **Dynamic workflows**: Create workflows that adapt to runtime conditions, handle variable data structures, and make decisions based on intermediate results.
- **Natural error handling**: Use standard Python `try`/`except` patterns throughout your workflows, making them more robust and easier to debug.
- **Intuitive composability**: Build complex workflows by naturally composing Python functions, following familiar patterns that any Python developer understands.
## Workflows can still be static when needed
> [!NOTE]
> This feature is coming soon.
The flexibility of dynamic workflows is absolutely needed for many use cases, but there are other scenarios where static workflows are beneficial. For these cases, Flyte 2 will offer compilation of the top-level task of a workflow into a static DAG.
This upcoming feature will provide:
- **Static analysis**: Enable workflow visualization and validation before execution
- **Predictable resources**: Allow precise resource planning and scheduling optimization
- **Traditional tooling**: Support existing DAG-based analysis and monitoring tools
- **Hybrid approach**: Choose between dynamic and static execution based on workflow characteristics
The static compilation system will naturally have limitations compared to fully dynamic workflows:
- **Dynamic fanouts**: Constructs that require runtime data to reify, for example, loops with an iteration-size that depends on intermediate results, will not be compilable.
  - However, constructs whose size and scope *can* be determined at registration time, such as fixed-size loops or maps, *will* be compilable.
- **Conditional branching**: Decision trees whose size and structure depend on intermediate results will not be compilable.
  - However, conditionals with fixed branch size will be compilable.
For the applications that require a predefined workflow graph, Flyte 2 will enable compilability up to the limits implicit in directed acyclic graphs.
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/flyte-2/async ===
# Asynchronous model
## Why we need an async model
The shift to an asynchronous model in Flyte 2 is driven by the need for more efficient and flexible workflow execution.
We believe, in particular, that with the rise of the agentic AI pattern, asynchronous programming has become an essential part of the AI/ML engineering and data science toolkit.
With Flyte 2, the entire framework is now written with async constructs, allowing for:
- Seamless overlapping of I/O and independent external operations.
- Composing multiple tasks and external tool invocations within the same Python process.
- Native support of streaming operations for data, observability and downstream invocations.
It is also a natural fit for expressing parallelism in workflows.
### Understanding concurrency vs. parallelism
Before diving into Flyte 2's approach, it's essential to understand the distinction between concurrency and parallelism:
| Concurrency | Parallelism |
| --- | --- |
| Dealing with multiple tasks at once through interleaved execution, even on a single thread. | Executing multiple tasks truly simultaneously across multiple cores or machines. |
| Performance benefits come from allowing the system to switch between tasks when one is waiting for external operations. | This is a subset of concurrency where tasks run at the same time rather than being interleaved. |
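The concurrency column can be observed with the standard library alone, outside Flyte. In this sketch, three coroutines each wait 0.2 seconds on a single thread; because each `await asyncio.sleep(...)` hands control back to the event loop, the waits are interleaved and the total stays close to one wait rather than three. The function names are illustrative only.

```python
import asyncio
import threading
import time

async def wait_for_io(name: str, seconds: float) -> str:
    # asyncio.sleep stands in for a network call or another external wait.
    await asyncio.sleep(seconds)
    # Report which thread ran this coroutine.
    return f"{name} on {threading.current_thread().name}"

async def main() -> tuple[list[str], float]:
    start = time.perf_counter()
    results = await asyncio.gather(
        wait_for_io("a", 0.2),
        wait_for_io("b", 0.2),
        wait_for_io("c", 0.2),
    )
    return results, time.perf_counter() - start

if __name__ == "__main__":
    results, elapsed = asyncio.run(main())
    print(results)  # all three report the same thread: concurrency, not parallelism
    print(f"{elapsed:.2f}s elapsed")
```

This is concurrency without parallelism: nothing ran simultaneously, but the waiting overlapped.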
### Python's async evolution
Python's asynchronous programming capabilities have evolved significantly:
- **The GIL challenge**: Python's Global Interpreter Lock (GIL) traditionally prevented true parallelism for CPU-bound tasks, limiting threading effectiveness to I/O-bound operations.
- **Traditional solutions**:
  - `multiprocessing`: Created separate processes to sidestep the GIL, effective but resource-intensive
  - `threading`: Useful for I/O-bound tasks where the GIL could be released during external operations
- **The async revolution**: The `asyncio` library introduced cooperative multitasking within a single thread, using an event loop to manage multiple tasks efficiently.
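The `threading` point can be checked with a short standard-library sketch (no Flyte involved). `time.sleep` releases the GIL while it waits, so four threads that each sleep 0.2 seconds finish in roughly 0.2 seconds overall. A pure-Python CPU-bound loop would not overlap this way in a standard CPython build with the GIL, which is why `multiprocessing` was the traditional route for CPU-bound work. The function names are illustrative.

```python
import time
from concurrent.futures import ThreadPoolExecutor

def blocking_io(i: int) -> int:
    # time.sleep releases the GIL, as a blocking socket or file read would.
    time.sleep(0.2)
    return i * 10

def run_threaded(n: int) -> tuple[list[int], float]:
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=n) as pool:
        # pool.map returns results in input order.
        results = list(pool.map(blocking_io, range(n)))
    return results, time.perf_counter() - start

if __name__ == "__main__":
    results, elapsed = run_threaded(4)
    print(results)  # [0, 10, 20, 30]
    print(f"{elapsed:.2f}s elapsed")
```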
### Parallelism in Flyte 1 vs Flyte 2
| | Flyte 1 | Flyte 2 |
| --- | --- | --- |
| Parallelism | The workflow DSL automatically parallelized tasks that weren't dependent on each other. The `map` operator allowed running a task multiple times in parallel with different inputs. | Leverages Python's `asyncio` as the primary mechanism for expressing parallelism, but with a crucial difference: **the Flyte orchestrator acts as the event loop**, managing task execution across distributed infrastructure. |
### Core async concepts
- **`async def`**: Declares a function as a coroutine. When called, it returns a coroutine object managed by the event loop rather than executing immediately.
- **`await`**: Pauses coroutine execution and passes control back to the event loop.
In standard Python, this enables other tasks to run while waiting for I/O operations.
In Flyte 2, it signals where tasks can be executed in parallel.
- **`asyncio.gather`**: The primary tool for concurrent execution.
In standard Python, it schedules multiple awaitable objects to run concurrently within a single event loop.
In Flyte 2, it signals to the orchestrator that these tasks can be distributed across separate compute resources.
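The standard-Python behavior of these three concepts can be seen in a few lines of plain `asyncio` (no Flyte). Calling an `async def` function only creates a coroutine object without running its body, and `asyncio.gather` returns results in the order the awaitables were passed, regardless of which one finishes first. The names below are illustrative.

```python
import asyncio

started: list[int] = []

async def square(x: int, delay: float) -> int:
    started.append(x)  # records when the body actually starts running
    await asyncio.sleep(delay)  # yields control back to the event loop
    return x * x

async def main() -> list[int]:
    coro = square(2, 0.0)
    # Calling the coroutine function has not run its body yet.
    assert asyncio.iscoroutine(coro) and started == []
    # gather schedules all awaitables; results follow argument order,
    # even though square(3, ...) finishes before square(4, ...).
    return await asyncio.gather(coro, square(4, 0.2), square(3, 0.1))

if __name__ == "__main__":
    print(asyncio.run(main()))  # [4, 16, 9]
```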
#### A practical example
Consider this pattern for parallel data processing:
```python
import asyncio

import flyte

env = flyte.TaskEnvironment("data_pipeline")

@env.task
async def process_chunk(chunk_id: int, data: str) -> str:
    # This could be any computational work - CPU or I/O bound
    await asyncio.sleep(1)  # Simulating work
    return f"Processed chunk {chunk_id}: {data}"

@env.task
async def parallel_pipeline(data_chunks: list[str]) -> list[str]:
    # Create coroutines for all chunks
    tasks = []
    for i, chunk in enumerate(data_chunks):
        tasks.append(process_chunk(i, chunk))

    # Execute all chunks in parallel
    results = await asyncio.gather(*tasks)
    return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py)
In standard Python, this would provide concurrency benefits primarily for I/O-bound operations.
In Flyte 2, the orchestrator schedules each `process_chunk` task on separate Kubernetes pods or configured plugins, achieving true parallelism for any type of work.
### True parallelism for all workloads
This is where Flyte 2's approach becomes revolutionary: **async syntax is not just for I/O-bound operations**.
The `async`/`await` syntax becomes a powerful way to declare your workflow's parallel structure for any type of computation.
When Flyte's orchestrator encounters `await asyncio.gather(...)`, it understands that these tasks are independent and can be executed simultaneously across different compute resources.
This means you achieve true parallelism for:
- **CPU-bound computations**: Heavy mathematical operations, model training, data transformations
- **I/O-bound operations**: Database queries, API calls, file operations
- **Mixed workloads**: Any combination of computational and I/O tasks
The Flyte platform handles the complex orchestration while you express parallelism using intuitive `async` syntax.
## Bridging the transition: Sync support and migration tools
### Seamless synchronous task support
Recognizing that many existing codebases use synchronous functions, Flyte 2 provides seamless backward compatibility:
```python
import asyncio

import flyte

env = flyte.TaskEnvironment("data_pipeline")

@env.task
def legacy_computation(x: int) -> int:
    # Existing synchronous function works unchanged
    return x * x + 2 * x + 1

@env.task
async def modern_workflow(numbers: list[int]) -> list[int]:
    # Call sync tasks from async context using .aio()
    tasks = []
    for num in numbers:
        tasks.append(legacy_computation.aio(num))
    results = await asyncio.gather(*tasks)
    return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py)
Under the hood, Flyte automatically "asyncifies" synchronous functions, wrapping them to participate seamlessly in the async execution model.
You don't need to rewrite existing code; just use the `.aio()` method when calling sync tasks from async contexts.
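As a rough mental model only (this is not Flyte's internal implementation), making a synchronous function awaitable alongside coroutines can be sketched in plain Python with `asyncio.to_thread`. The `asyncify` helper below is hypothetical; in Flyte you call `legacy_computation.aio(...)` instead.

```python
import asyncio
import functools
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")

def asyncify(fn: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Hypothetical helper: return an async wrapper around a sync function."""
    @functools.wraps(fn)
    async def wrapper(*args, **kwargs) -> T:
        # Run the blocking function in a worker thread so the event loop stays free.
        return await asyncio.to_thread(fn, *args, **kwargs)
    return wrapper

def legacy_computation(x: int) -> int:
    # Unchanged synchronous code: (x + 1) squared.
    return x * x + 2 * x + 1

async def modern_workflow(numbers: list[int]) -> list[int]:
    aio_computation = asyncify(legacy_computation)
    return await asyncio.gather(*(aio_computation(n) for n in numbers))

if __name__ == "__main__":
    print(asyncio.run(modern_workflow([1, 2, 3])))  # [4, 9, 16]
```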
### The `flyte.map` function: Familiar patterns
For scenarios that previously used Flyte 1's `map` operation, Flyte 2 provides `flyte.map` as a direct replacement.
The new `flyte.map` can be used either in synchronous or asynchronous contexts, allowing you to express parallelism without changing your existing patterns.
### Sync Map
```python
@env.task
def sync_map_example(n: int) -> list[str]:
    # Synchronous version for easier migration
    results = []
    for result in flyte.map(process_item, range(n)):
        if isinstance(result, Exception):
            raise result
        results.append(result)
    return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py)
### Async Map
```python
@env.task
async def async_map_example(n: int) -> list[str]:
    # Async version using flyte.map - exact pattern from SDK examples
    results = []
    async for result in flyte.map.aio(process_item, range(n), return_exceptions=True):
        if isinstance(result, Exception):
            raise result
        results.append(result)
    return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/flyte-2/async/async.py)
The `flyte.map` function provides:
- **Dual interfaces**: `flyte.map.aio()` for async contexts, `flyte.map()` for sync contexts.
- **Built-in error handling**: `return_exceptions` parameter for graceful failure handling. This matches the `asyncio.gather` interface,
allowing you to decide how to handle errors.
If you are coming from Flyte 1, it allows you to replace `min_success_ratio` in a more flexible way.
- **Automatic UI grouping**: Creates logical groups for better workflow visualization.
- **Concurrency control**: Optional limits for resource management.
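The error-handling and concurrency ideas in this list can be sketched with plain `asyncio`. The `bounded_map` helper and this `process_item` are hypothetical and are not `flyte.map` or the task from the examples above: the helper caps in-flight calls with an `asyncio.Semaphore` and uses `asyncio.gather(..., return_exceptions=True)` so failures come back as values. The last step shows how a Flyte 1 style `min_success_ratio` check could be written as ordinary code.

```python
import asyncio
from typing import Any, Awaitable, Callable, Iterable

async def bounded_map(
    fn: Callable[[Any], Awaitable[Any]],
    items: Iterable[Any],
    concurrency: int,
) -> list[Any]:
    """Hypothetical helper: apply fn to items with at most `concurrency` calls in flight."""
    semaphore = asyncio.Semaphore(concurrency)

    async def guarded(item: Any) -> Any:
        async with semaphore:
            return await fn(item)

    # Exceptions are returned in place of results instead of being raised.
    return await asyncio.gather(*(guarded(i) for i in items), return_exceptions=True)

async def process_item(i: int) -> str:
    await asyncio.sleep(0.01)
    if i % 4 == 3:  # simulate an occasional failure
        raise ValueError(f"item {i} failed")
    return f"item {i} ok"

async def main(n: int, min_success_ratio: float) -> list[str]:
    results = await bounded_map(process_item, range(n), concurrency=2)
    successes = [r for r in results if not isinstance(r, Exception)]
    # Decide in code how many failures are acceptable.
    if len(successes) / n < min_success_ratio:
        raise RuntimeError(f"only {len(successes)}/{n} items succeeded")
    return successes

if __name__ == "__main__":
    print(asyncio.run(main(8, min_success_ratio=0.7)))
```

With 8 items, items 3 and 7 fail, so 6 of 8 succeed (a ratio of 0.75) and the call passes a 0.7 threshold but would raise at 0.8.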
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/flyte-2/migration ===
# Migration from Flyte 1 to Flyte 2
> [!NOTE]
> Automated migration from Flyte 1 to Flyte 2 is coming soon.
Until automated migration is available, you can migrate manually by following the steps below:
### 1. Move task configuration to a `TaskEnvironment` object
Instead of configuring the image, hardware resources, and so forth directly in the task decorator, you configure them in a `TaskEnvironment` object. For example:
```python
import flyte

env = flyte.TaskEnvironment(name="my_task_env")
```
### 2. Replace workflow decorators
Then, you replace the `@workflow` and `@task` decorators with `@env.task` decorators.
### Flyte 1
Here's a simple hello world example with fan-out.
```python
import flytekit

@flytekit.task
def hello_world(name: str) -> str:
    return f"Hello, {name}!"

@flytekit.workflow
def main(names: list[str]) -> list[str]:
    return flytekit.map(hello_world)(names)
```
### Flyte 2 Sync
Change all the decorators to `@env.task` and swap out `flytekit.map` with `flyte.map`.
Notice that `flyte.map` is a drop-in replacement for Python's built-in `map` function.
```diff
-@flytekit.task
+@env.task
 def hello_world(name: str) -> str:
     return f"Hello, {name}!"
-@flytekit.workflow
+@env.task
 def main(names: list[str]) -> list[str]:
-    return flytekit.map(hello_world)(names)
+    return list(flyte.map(hello_world, names))
```
> **Note**
>
> Our task decorator uses `env` simply because that is the variable to which we assigned the `TaskEnvironment` above.
### Flyte 2 Async
To take advantage of full concurrency (not just parallelism), use Python async
syntax and the `asyncio` standard library to implement fan-out.
```diff
+import asyncio
 @env.task
-def hello_world(name: str) -> str:
+async def hello_world(name: str) -> str:
     return f"Hello, {name}!"
 @env.task
-def main(names: list[str]) -> list[str]:
+async def main(names: list[str]) -> list[str]:
-    return list(flyte.map(hello_world, names))
+    return await asyncio.gather(*[hello_world(name) for name in names])
```
> **Note**
>
> To use Python async syntax, you need to:
> - Use `asyncio.gather()` or `flyte.map()` for parallel execution
> - Add `async`/`await` keywords where you want parallelism
> - Keep existing sync task functions unchanged
>
> Learn more about the benefits of async in the **Flyte 2 > Asynchronous model** guide.
### 3. Leverage enhanced capabilities
- Add conditional logic and loops within workflows
- Implement proper error handling with try/except
- Create dynamic workflows that adapt to runtime conditions
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/getting-started ===
# Getting started
This section gives you a quick introduction to writing and running workflows on Union and Flyte 2.
## Prerequisites
### Install uv
First, [install the `uv` package manager](https://docs.astral.sh/uv/getting-started/installation/).
> [!NOTE]
> You will need to use the [`uv` package manager](https://docs.astral.sh/uv/) to run the examples in this guide.
> In particular, we leverage `uv`'s ability to [embed dependencies directly in scripts](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies).
### Install Python 3.10 or later
Flyte 2 requires Python 3.10 or later.
Install the most recent version of Python (>= 3.10) compatible with your codebase and pin it.
For example, to install and pin Python 3.13, do the following:
```shell
uv python install 3.13
uv python pin 3.13 --global
```
### Create a Python virtual environment
In your working directory, create a Python virtual environment and activate it:
```shell
uv venv
source .venv/bin/activate
```
## Install the `flyte` package
Install the latest `flyte` package in the virtual environment (we are currently in beta, so you will have to enable prerelease installation):
```shell
uv pip install --no-cache --prerelease=allow --upgrade flyte
```
## Create a config.yaml
Next, create a configuration file that points to your Union instance.
Use the **Flyte CLI > flyte > flyte create > flyte create config** command, making the following changes:
- Replace `my-org.my-company.com` with the actual URL of your Union backend instance.
You can simply copy the domain part of the URL from your browser when logged into your backend instance.
- Replace `my-project` with an actual project.
The project you specify must already exist on your Union backend instance.
```shell
flyte create config \
--endpoint my-org.my-company.com \
--builder remote \
--domain development \
--project my-project
```
By default, this will create a `./.flyte/config.yaml` file in your current working directory.
See **Getting started > Local setup > Setting up a configuration file** for details.
> **Note**
>
> Run `flyte get config` to see the current configuration file being used by the `flyte` CLI.
## Hello world example
Create a file called `hello.py` with the following content:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
# ]
# main = "main"
# params = "x_list=[1,2,3,4,5,6,7,8,9,10]"
# ///

import flyte

# A TaskEnvironment provides a way of grouping the configuration used by tasks.
env = flyte.TaskEnvironment(name="hello_world")

# Use a TaskEnvironment to define tasks, which are regular Python functions.
@env.task
def fn(x: int) -> int:  # Type annotations are recommended.
    slope, intercept = 2, 5
    # raise ValueError("I will fail!")
    return slope * x + intercept

# Tasks can call other tasks.
# Each task defined with a given TaskEnvironment will run in its own separate container,
# but the containers will all be configured identically.
@env.task
def main(x_list: list[int] = list(range(10))) -> float:
    x_len = len(x_list)
    if x_len < 10:
        raise ValueError(f"x_list doesn't have a large enough sample size, found: {x_len}")

    # flyte.map is like Python map, but runs in parallel.
    y_list = list(flyte.map(fn, x_list))
    y_mean = sum(y_list) / len(y_list)
    return y_mean

# Running this script locally will perform a flyte.run,
# which will deploy your task code to your remote Union/Flyte instance.
if __name__ == "__main__":
    # Initialize Flyte from a config file.
    flyte.init_from_config()

    # Run your tasks remotely inline and pass parameter data.
    r = flyte.run(main, x_list=list(range(10)))

    # Print various attributes of the run.
    print(r.name)
    print(r.url)

    # Stream the logs from the remote run to the terminal.
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/getting-started/hello.py)
## Understanding the code
In the code above we do the following:
- Import the `flyte` package.
- Define a `TaskEnvironment` to group the configuration used by tasks.
- Define two tasks using the `@env.task` decorator.
  - Tasks are regular Python functions, but when deployed to your Union/Flyte instance, each task execution runs in its own separate container.
  - Both tasks use the same `env` (the same `TaskEnvironment`), so while each runs in its own container, those containers are configured identically.
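Because the task bodies are ordinary Python, you can sanity-check the logic locally before running it remotely. This sketch copies `fn` and `main` without the `@env.task` decorators and uses the built-in `map` in place of `flyte.map`, so it runs sequentially on your machine and involves no Union/Flyte instance; it only confirms what the remote run should compute.

```python
def fn(x: int) -> int:
    slope, intercept = 2, 5
    return slope * x + intercept

def main(x_list: list[int] = list(range(10))) -> float:
    x_len = len(x_list)
    if x_len < 10:
        raise ValueError(f"x_list doesn't have a large enough sample size, found: {x_len}")
    # Built-in map stands in for flyte.map here; it runs sequentially.
    y_list = list(map(fn, x_list))
    return sum(y_list) / len(y_list)

if __name__ == "__main__":
    # Mean of 0..9 is 4.5, so the result is 2 * 4.5 + 5.
    print(main(list(range(10))))  # 14.0
```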
## Running the code
Make sure that your `config.yaml` file is in the same directory as your `hello.py` script.
Now, run the script with:
```shell
uv run --prerelease allow hello.py
```
The main guard section in the script performs a `flyte.init_from_config` to set up the connection with your Union/Flyte instance and a `flyte.run` to send your task code to that instance and execute it there.
> [!NOTE]
> The example scripts in this guide have a main guard that programmatically deploys and runs the tasks defined in the same file.
> All you have to do is execute the script itself.
> You can also deploy tasks using the `flyte` CLI instead. We will cover this in a later section.
## Viewing the results
In your terminal, you should see output like this:
```shell
cg9s54pksbjsdxlz2gmc
https://my-instance.example.com/v2/runs/project/my-project/domain/development/cg9s54pksbjsdxlz2gmc
Run 'a0' completed successfully.
```
Click the link to go to your Union instance and see the run in the UI:

## Subpages
- **Getting started > Local setup**
- **Getting started > Running**
- **Getting started > Anatomy of a run**
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/getting-started/local-setup ===
# Local setup
In this section we will explain the options for configuring the `flyte` CLI and SDK to connect to your Union/Flyte instance.
Before proceeding, make sure you have completed the steps in **Getting started**.
You will need to have the `uv` tool and the `flyte` Python package installed.
## Setting up a configuration file
In **Getting started** we used the `flyte create config` command to create a configuration file at `./.flyte/config.yaml`.
```shell
flyte create config \
--endpoint my-org.my-company.com \
--project my-project \
--domain development \
--builder remote
```
The result of the above command would be the creation of a file called `./.flyte/config.yaml` in your current working directory
with the following content:
```yaml
admin:
  endpoint: dns:///my-org.my-company.com
image:
  builder: remote
task:
  domain: development
  org: my-org
  project: my-project
```
The example below creates a configuration file called `my-config.yaml` in the current working directory, using all of the available options:
```shell
flyte create config \
--endpoint my-org.my-company.com \
--insecure \
--builder remote \
--domain development \
--org my-org \
--project my-project \
--output my-config.yaml \
--force
```
See the **Flyte CLI > flyte > flyte create > flyte create config** section for details on the available parameters.
Notes about the properties in the config file:
**`admin` section**: contains the connection details for your Union/Flyte instance.
* `admin.endpoint` is the URL (always with `dns:///` prefix) of your Union/Flyte instance.
If your instance UI is found at https://my-org.my-company.com, the actual endpoint used in this file would be `dns:///my-org.my-company.com`.
* `admin.insecure` indicates whether to use an insecure connection (without TLS) to the Union/Flyte instance.
A setting of `true` is almost always only used for connecting to a local instance on your own machine.
**`image` section**: contains the configuration for building Docker images for your tasks.
* `image.builder` specifies the image builder to use for building Docker images for your tasks.
  * For Union instances this is usually set to `remote`, which means that the images will be built on Union's infrastructure using the Union `ImageBuilder`.
  * For Flyte OSS instances, `ImageBuilder` is not available, so this property must be set to `local`.
    This means that the images will be built locally on your machine.
    You need to have Docker installed and running for this to work.
    See **Configure tasks > Container images > Image building** for details.
**`task` section**: contains the configuration for running tasks on your Union/Flyte instance.
* `task.domain` specifies the domain in which the tasks will run.
Domains are used to separate different environments, such as `development`, `staging`, and `production`.
* `task.org` specifies the organization in which the tasks will run. The organization usually corresponds to the name of the Union instance you are using, which is typically the first part of the `admin.endpoint` hostname (for example, `my-org` in `dns:///my-org.my-company.com`).
* `task.project` specifies the project in which the tasks will run. The project you specify here will be the default project to which tasks are deployed if no other project is specified. The project you specify must already exist on your Union/Flyte instance (it will not be auto-created on first deploy).
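The relationships described above can be illustrated in plain Python. The helpers below are hypothetical (they are not part of the Flyte SDK or CLI); they assume, as stated above, that the endpoint always carries the `dns:///` prefix and that the org usually matches the first label of the hostname, and they render the same nested `admin` / `image` / `task` structure shown in the config file.

```python
from typing import Optional
from urllib.parse import urlparse


def normalize_endpoint(url_or_host: str) -> str:
    """Turn a UI URL or bare hostname into a `dns:///`-prefixed endpoint."""
    if url_or_host.startswith("dns:///"):
        return url_or_host
    # urlparse only fills in netloc when the input has a scheme (https://...).
    host = urlparse(url_or_host).netloc or url_or_host
    return f"dns:///{host.rstrip('/')}"


def guess_org(endpoint: str) -> str:
    """Assumption: the org is the first DNS label of the endpoint host."""
    return endpoint.removeprefix("dns:///").split(".")[0]


def make_config(endpoint: str, project: str, domain: str = "development",
                org: Optional[str] = None, builder: str = "remote",
                insecure: bool = False) -> dict:
    """Build the nested config mapping (hypothetical helper)."""
    endpoint = normalize_endpoint(endpoint)
    return {
        "admin": {"endpoint": endpoint, "insecure": insecure},
        "image": {"builder": builder},
        "task": {"domain": domain, "org": org or guess_org(endpoint), "project": project},
    }


def to_yaml(config: dict) -> str:
    """Render a two-level mapping as YAML (enough for this config shape)."""
    lines = []
    for section, values in config.items():
        lines.append(f"{section}:")
        for key, value in values.items():
            if isinstance(value, bool):
                value = "true" if value else "false"
            lines.append(f"  {key}: {value}")
    return "\n".join(lines) + "\n"


cfg = make_config("https://my-org.my-company.com", project="my-project")
print(to_yaml(cfg))
```

In real use, `flyte create config` writes this file for you; the sketch only makes the endpoint-prefix and org conventions explicit.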
## Using the configuration file
You can use the configuration file either explicitly by referencing it directly from a CLI or Python command, or implicitly by placing it in a specific location or setting an environment variable.
### Specify a configuration file explicitly
When using the `flyte` CLI, you can specify the configuration file explicitly with the `--config` parameter, like this:
```shell
flyte --config my-config.yaml run hello.py main
```
or just using the `-c` shorthand:
```shell
flyte -c my-config.yaml run hello.py main
```
When using the Flyte SDK programmatically, you must first initialize it with the configuration file.
To initialize with an explicitly specified configuration file, use **Getting started > Local setup > `flyte.init_from_config`**:
```python
import flyte

flyte.init_from_config("my-config.yaml")
```
Then you can continue with other `flyte` calls, such as running the `main` task:
```python
run = flyte.run(main)
```
### Use the configuration file implicitly
You can also use the configuration file implicitly by placing it in a specific location or setting an environment variable.
You can use the `flyte` CLI without an explicit `--config`, like this:
```shell
flyte run hello.py main
```
You can also initialize the Flyte SDK programmatically without specifying a configuration file, like this:
```python
flyte.init_from_config()
```
In these cases, the SDK will search in the following order until it finds a configuration file:
* `./config.yaml` (i.e., in the current working directory).
* `./.flyte/config.yaml` (i.e., in the `.flyte` directory in the current working directory).
* `UCTL_CONFIG` (a file pointed to by this environment variable).
* `FLYTECTL_CONFIG` (a file pointed to by this environment variable).
* `~/.union/config.yaml`.
* `~/.flyte/config.yaml`.
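To make the lookup order concrete, here is a minimal sketch that mirrors the list above. `find_config` is a hypothetical helper, not the SDK's actual implementation; it takes the working directory, home directory, environment mapping, and an existence check as arguments so the order can be reasoned about without touching the real filesystem.

```python
import os
from pathlib import Path
from typing import Callable, Mapping, Optional


def find_config(cwd: Path, home: Path, env: Mapping[str, str],
                exists: Callable[[Path], bool] = Path.is_file) -> Optional[Path]:
    """Return the first existing config file, following the documented order."""
    candidates = [
        cwd / "config.yaml",               # ./config.yaml
        cwd / ".flyte" / "config.yaml",    # .flyte directory in the working dir
    ]
    # Environment variables are only consulted when they are set.
    for var in ("UCTL_CONFIG", "FLYTECTL_CONFIG"):
        if env.get(var):
            candidates.append(Path(env[var]))
    candidates += [
        home / ".union" / "config.yaml",
        home / ".flyte" / "config.yaml",
    ]
    for path in candidates:
        if exists(path):
            return path
    return None


if __name__ == "__main__":
    print(find_config(Path.cwd(), Path.home(), dict(os.environ)))
```

A practical consequence of this order: a stray `config.yaml` in your project directory takes precedence over the one in `~/.flyte/`, which is a common reason for connecting to an unexpected instance. `flyte get config` (below) shows which file was actually used.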
### Checking your configuration
You can check your current configuration by running the following command:
```shell
flyte get config
```
This will return the current configuration as a serialized Python object. For example:
```text
CLIConfig(
    Config(
        platform=PlatformConfig(endpoint='dns:///my-org.my-company.com', scopes=[]),
        task=TaskConfig(org='my-org', project='my-project', domain='development'),
        source=PosixPath('/Users/me/.flyte/config.yaml')
    ),
    ,
    log_level=None,
    insecure=None
)
```
## Inline configuration
### With `flyte` CLI
You can also use the Flyte SDK with inline configuration parameters instead of a configuration file.
When using the `flyte` CLI, some parameters are specified after the top-level command (i.e., `flyte`) while others are specified after the subcommand (for example, `run`).
For example, you can run a workflow using the following command:
```shell
flyte \
--endpoint my-org.my-company.com \
--org my-org \
run \
--domain development \
--project my-project \
hello.py \
main
```
See the **Flyte CLI** for details.
### With `flyte` SDK
When using the Flyte SDK programmatically, you can use the **Flyte SDK > Packages > flyte > Methods > init()** function to specify the backend endpoint and other parameters directly in your code, like this:
```python
import flyte

flyte.init(
    endpoint="dns:///my-org.my-company.com",
    org="my-org",
    project="my-project",
    domain="development",
)
```
See the **Flyte SDK > Packages > flyte > Methods > init()** for details.
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/getting-started/running ===
# Running
Flyte SDK lets you seamlessly switch between running your workflows locally on your machine and running them remotely on your Union/Flyte instance.
Furthermore, you can perform these actions either programmatically from within Python code or from the command line using the `flyte` CLI.
## Running remotely
### From the command-line
To run your code on your Union/Flyte instance, you can use the `flyte run` command without the `--local` flag:
```shell
flyte run hello.py main
```
This deploys your code to the configured Union/Flyte instance and runs it immediately. Since no explicit `--config` is specified, the configuration file is located as described in **Getting started > Local setup > Using the configuration file > Use the configuration file implicitly**.
### From Python
To run your workflow remotely from Python, use **Flyte SDK > Packages > flyte > Methods > run()** by itself, like this:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
# ]
# main = "main"
# params = "name='World'"
# ///
# run_from_python.py
import flyte

env = flyte.TaskEnvironment(name="hello_world")


@env.task
def main(name: str) -> str:
    return f"Hello, {name}!"


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, name="World")
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/getting-started/running/run_from_python.py)
This is the approach we use throughout the examples in this guide.
Executing the script invokes `flyte.run()` with the top-level task as its argument.
`flyte.run()` then deploys the code in that file to your remote Union/Flyte instance and runs it there.
## Running locally
### From the command-line
To run your code on your local machine, you can use the `flyte run` command with the `--local` flag:
```shell
flyte run --local hello.py main
```
### From Python
To run your workflow locally from Python, you chain **Getting started > Running > `flyte.with_runcontext()`** with **Flyte SDK > Packages > flyte > Methods > run()** and specify the run `mode="local"`, like this:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
# ]
# main = "main"
# params = "name='World'"
# ///
# run_local_from_python.py
import flyte

env = flyte.TaskEnvironment(name="hello_world")


@env.task
def main(name: str) -> str:
    return f"Hello, {name}!"


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.with_runcontext(mode="local").run(main, name="World")
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/getting-started/running/run_local_from_python.py)
Running your workflow locally is useful for testing and debugging, as it allows you to run your code without deploying it to a remote instance.
It also lets you quickly iterate on your code without the overhead of deployment.
If your code relies on remote resources or services, you will need to mock them in your local environment or temporarily work around any missing functionality.
At the very least, local execution can be used to catch immediate syntax errors and other relatively simple issues before deploying your code to a remote instance.
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/getting-started/anatomy-of-a-run ===
# Anatomy of a run
To understand how Flyte 2 works, it helps to establish a few definitions and concepts.
* **Workflow**: A collection of tasks linked by invocation, with a top-most task that is the entry point of the workflow.
We sometimes refer to this as the "parent", "driver", or "top-most" task.
Unlike in Flyte 1, there is no explicit `@workflow` decorator; instead, the workflow is defined implicitly by the structure of the code.
Nonetheless, you will often see the assemblage of tasks referred to as a "workflow".
* **TaskEnvironment**: A `TaskEnvironment` object is the abstraction that defines the hardware and software environment in which one or more tasks are executed.
* The hardware environment is specified by parameters that define the type of compute resources (e.g., CPU, memory) allocated to the task.
* The software environment is specified by parameters that define the container image, including dependencies, required to run the task.
* **Task**: A Python function.
* Tasks are defined using the `@env.task` decorator, where the `env` refers to a `TaskEnvironment` object.
* Tasks can involve invoking helper functions as well as other tasks and assembling outputs from those invocations.
* **Run**: A run is the execution of a task directly initiated by a user and all its descendant tasks, considered together.
* **Action**: An action is the execution of a single task, considered independently. A run consists of one or more actions.
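The run/action relationship is essentially a tree: each action may invoke further tasks, and the run is the root action together with everything beneath it. The sketch below models that with a hypothetical `Action` dataclass; it is a conceptual illustration only, not a Flyte type, and the task names are made up.

```python
from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class Action:
    """One execution of a single task (hypothetical model)."""
    task: str
    children: List["Action"] = field(default_factory=list)


def actions_in_run(root: Action) -> Iterator[Action]:
    """A run is the root action plus all of its descendants (depth-first)."""
    yield root
    for child in root.children:
        yield from actions_in_run(child)


# A top-level task that invokes two tasks, one of which invokes a third.
run = Action("my_env.main", [
    Action("my_env.fetch"),
    Action("my_env.process", [Action("my_env.summarize")]),
])
print([a.task for a in actions_in_run(run)])
```

Here one run consists of four actions, matching the definition above that a run consists of one or more actions.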
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-configuration ===
# Configure tasks
As we saw in **Getting started**, you can run any Python function as a task in Flyte just by decorating it with `@env.task`.
This allows you to run your Python code in a distributed manner, with each function running in its own container.
Flyte manages the spinning up of the containers, the execution of the code, and the passing of data between the tasks.
The simplest possible case is a `TaskEnvironment` with only a `name` parameter and an `@env.task` decorator with no parameters:
```python
import flyte

env = flyte.TaskEnvironment(name="my_env")


@env.task
async def my_task(name: str) -> str:
    return f"Hello {name}!"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/task_config.py)
> [!NOTE]
> Notice how the `TaskEnvironment` is assigned to the variable `env` and then that variable is
> used in the `@env.task`. This is what connects the `TaskEnvironment` to the task definition.
>
> In the following we will often use `@env.task` generically to refer to the decorator,
> but it is important to remember that it is actually a decorator attached to a specific
> `TaskEnvironment` object, and the `env` part can be any variable name you like.
This will run your task in the default container environment with default settings.
But, of course, one of the key advantages of Flyte is the ability to control the software environment, hardware environment, and other execution parameters for each task, right in your Python code.
In this section we will explore the various configuration options available for tasks in Flyte.
## Task configuration levels
Task configuration is done at three levels. From most general to most specific, they are:
* The `TaskEnvironment` level: setting parameters when defining the `TaskEnvironment` object.
* The `@env.task` decorator level: setting parameters in the `@env.task` decorator when defining a task function.
* The task invocation level: using the `task.override()` method when invoking task execution.
Each level has its own set of parameters, and some parameters are shared across levels.
For shared parameters, the more specific level will override the more general one.
### Example
Here is an example of how these levels work together, showing each level with most of its available parameters:
```python
import flyte

# A second environment, defined here only so that `depends_on` has something to reference
another_env = flyte.TaskEnvironment(name="another_env")

# Level 1: TaskEnvironment - Base configuration
env_2 = flyte.TaskEnvironment(
    name="data_processing_env",
    image=flyte.Image.from_debian_base(),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    env_vars={"MY_VAR": "value"},
    # secrets=flyte.Secret(key="openapi_key", as_env_var="MY_API_KEY"),
    cache="disable",
    # pod_template=my_pod_template,
    # reusable=flyte.ReusePolicy(replicas=2, idle_ttl=300),
    depends_on=[another_env],
    description="Data processing task environment",
    # plugin_config=my_plugin_config
)


# Level 2: Decorator - Override some environment settings
@env_2.task(
    short_name="process",
    # secrets=flyte.Secret(key="openapi_key", as_env_var="MY_API_KEY_2"),
    cache="auto",
    # pod_template=my_pod_template,
    report=True,
    max_inline_io_bytes=100 * 1024,
    retries=3,
    timeout=60,
    docs="This task processes data and generates a report."
)
async def process_data(data_path: str) -> str:
    return f"Processed {data_path}"


# Level 3: Invocation - Override settings for a single call
@env_2.task
async def invoke_process_data() -> str:
    result = await process_data.override(
        resources=flyte.Resources(cpu=4, memory="2Gi"),
        env_vars={"MY_VAR": "new_value"},
        # secrets=flyte.Secret(key="openapi_key", as_env_var="MY_API_KEY_3"),
        cache="auto",
        max_inline_io_bytes=100 * 1024,
        retries=3,
        timeout=60
    )("input.csv")
    return result
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/task_config.py)
### Parameter interaction
Here is an overview of all task configuration parameters available at each level and how they interact:
| Parameter | `TaskEnvironment` | `@env.task` decorator | `override` on task invocation |
|-------------------------|--------------------|----------------------------|-------------------------------|
| **name** | Yes (required) | No | No |
| **short_name** | No | Yes | Yes |
| **image** | Yes | No | No |
| **resources** | Yes | No | Yes (if not `reusable`) |
| **env_vars** | Yes | No | Yes (if not `reusable`) |
| **secrets** | Yes | No | Yes (if not `reusable`) |
| **cache** | Yes | Yes | Yes |
| **pod_template** | Yes | Yes | Yes |
| **reusable** | Yes | No | Yes |
| **depends_on** | Yes | No | No |
| **description** | Yes | No | No |
| **plugin_config** | Yes | No | No |
| **report** | No | Yes | No |
| **max_inline_io_bytes** | No | Yes | Yes |
| **retries** | No | Yes | Yes |
| **timeout** | No | Yes | Yes |
| **triggers** | No | Yes | No |
| **interruptible** | Yes | Yes | Yes |
| **queue** | Yes | Yes | Yes |
| **docs** | No | Yes | No |
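The table and the "more specific level wins" rule can be captured in a small model. The sketch below is hypothetical (it is not how the Flyte SDK resolves configuration internally): it transcribes the table into `ALLOWED`, rejects a parameter set at a level the table marks as "No", and lets later levels overwrite earlier ones. It replaces values wholesale; whether Flyte merges dictionaries such as `env_vars` key by key is not modeled, and neither are the `reusable` caveats.

```python
from typing import Any, Dict

# Which levels accept each parameter, transcribed from the table above.
ALLOWED = {
    "name": {"env"},
    "short_name": {"decorator", "override"},
    "image": {"env"},
    "resources": {"env", "override"},
    "env_vars": {"env", "override"},
    "secrets": {"env", "override"},
    "cache": {"env", "decorator", "override"},
    "pod_template": {"env", "decorator", "override"},
    "reusable": {"env", "override"},
    "depends_on": {"env"},
    "description": {"env"},
    "plugin_config": {"env"},
    "report": {"decorator"},
    "max_inline_io_bytes": {"decorator", "override"},
    "retries": {"decorator", "override"},
    "timeout": {"decorator", "override"},
    "triggers": {"decorator"},
    "interruptible": {"env", "decorator", "override"},
    "queue": {"env", "decorator", "override"},
    "docs": {"decorator"},
}


def effective_config(env: Dict[str, Any], decorator: Dict[str, Any],
                     override: Dict[str, Any]) -> Dict[str, Any]:
    """Merge the three levels; the more specific level wins for shared keys."""
    merged: Dict[str, Any] = {}
    for level, params in (("env", env), ("decorator", decorator), ("override", override)):
        for key, value in params.items():
            if level not in ALLOWED.get(key, set()):
                raise ValueError(f"{key!r} cannot be set at the {level} level")
            merged[key] = value
    return merged


cfg = effective_config(
    env={"name": "data_processing_env", "cache": "disable", "env_vars": {"MY_VAR": "value"}},
    decorator={"cache": "auto", "retries": 3},
    override={"env_vars": {"MY_VAR": "new_value"}, "retries": 5},
)
print(cfg)
```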
## Task configuration parameters
The full set of parameters available for configuring a task environment, task definition, and task invocation are:
### `name`
* Type: `str` (required)
* Defines the name of the `TaskEnvironment`.
Since it specifies the name *of the environment*, it cannot, logically, be overridden at the `@env.task` decorator or the `task.override()` invocation level.
It is used in conjunction with the name of each `@env.task` function to define the fully-qualified name of the task.
The fully qualified name is always the `TaskEnvironment` name (the one above) followed by a period and then the task function name (the name of the Python function being decorated).
For example:
```
env = flyte.TaskEnvironment(name="my_env")
@env.task
async def my_task(name:str) -> str:
return f"Hello {name}!"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/task_config.py)
Here, the name of the TaskEnvironment is `my_env` and the fully qualified name of the task is `my_env.my_task`.
The `TaskEnvironment` name and the fully qualified name of a task are both fixed and cannot be overridden.
### `short_name`
* Type: `str`
* Defines the short name of the task or action (the execution of a task).
Since it specifies the name *of the task*, it cannot, logically, be set at the `TaskEnvironment` level.
By default, the short name of a task is the name of the task function (the name of the Python function being decorated).
The short name is used, for example, in parts of the UI.
Overriding it does not change the fully qualified name of the task.
* Can be set at the `@env.task` decorator level and overridden at the `task.override()` invocation level.
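The naming rules for `name` and `short_name` fit in a few lines of plain Python. `task_names` is a hypothetical helper written for illustration, not part of the Flyte SDK; it only restates the documented behavior: the fully qualified name is the environment name, a period, and the function name, while the short name defaults to the function name and can be changed without affecting the fully qualified name.

```python
from typing import Callable, Optional, Tuple


def task_names(env_name: str, func: Callable,
               short_name: Optional[str] = None) -> Tuple[str, str]:
    """Return (fully_qualified_name, short_name) following the rules above."""
    fully_qualified = f"{env_name}.{func.__name__}"
    # short_name defaults to the function name; overriding it never
    # changes the fully qualified name.
    return fully_qualified, short_name or func.__name__


async def my_task(name: str) -> str:
    return f"Hello {name}!"


print(task_names("my_env", my_task))
print(task_names("my_env", my_task, short_name="greet"))
```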
### `image`
* Type: `Union[str, Image, Literal['auto']]`
* Specifies the Docker image to use for the task container.
Can be a URL reference to a Docker image, a **Configure tasks > `Image` object**, or the string `auto`.
If set to `auto`, or if this parameter is not set, the default Flyte image will be used (see **Configure tasks > Container images**).
* Only settable at the `TaskEnvironment` level.
* See **Configure tasks > Container images**.
### `resources`
* Type: `Optional[Resources]`
* Specifies the compute resources, such as CPU and memory, required by the task environment using a
**Configure tasks > `Resources`** object.
* Can be set at the `TaskEnvironment` level and overridden at the `task.override()` invocation level
(but only if `reusable` is not in effect).
* See **Configure tasks > Resources**.
### `env_vars`
* Type: `Optional[Dict[str, str]]`
* A dictionary of environment variables to be made available in the task container.
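These variables can be used to configure the task at runtime, such as setting API keys or other configuration values.
* Can be set at the `TaskEnvironment` level and overridden at the `task.override()` invocation level, but only if `reusable` is not in effect.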
### `secrets`
* Type: `Optional[SecretRequest]` where `SecretRequest` is an alias for `Union[str, Secret, List[str | Secret]]`
* The secrets to be made available in the task container.
* Can be set at the `TaskEnvironment` level and overridden at the `task.override()` invocation level, but only if `reusable` is not in effect.
* See **Configure tasks > Secrets** and the API docs for the **Configure tasks > `Secret` object**.
### `cache`
* Type: `Union[CacheRequest]` where `CacheRequest` is an alias for `Literal["auto", "override", "disable", "enabled"] | Cache`.
* Specifies the caching policy to be used for this task.
* Can be set at the `TaskEnvironment` level and overridden at the `@env.task` decorator level
and at the `task.override()` invocation level.
* See **Configure tasks > Caching**.
### `pod_template`
* Type: `Optional[Union[str, kubernetes.client.V1PodTemplate]]`
* A pod template that defines the Kubernetes pod configuration for the task, given either as a string reference to a named template or as a `kubernetes.client.V1PodTemplate` object.
* Can be set at the `TaskEnvironment` level and overridden at the `@env.task` decorator level and the `task.override()` invocation level.
* See **Configure tasks > Pod templates**.
### `reusable`
* Type: `ReusePolicy | None`
* A `ReusePolicy` that defines whether the task environment can be reused.
If set, the task environment will be reused across multiple task invocations.
* When a `TaskEnvironment` has `reusable` set, then `resources`, `env_vars`, and `secrets` can only be overridden in `task.override()`
if accompanied by an explicit `reusable="off"` in the same `task.override()` invocation.
Additionally, `secrets` can only be overridden at the `@env.task` decorator level if the `TaskEnvironment` (`env`) does not have `reusable` set.
* See **Configure tasks > Reusable containers** and the API docs for the **Configure tasks > `ReusePolicy` object**.
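The `reusable` rule for `task.override()` can be expressed as a simple check. The sketch below is a hypothetical validator that encodes only the rule stated above (`resources`, `env_vars`, and `secrets` may be overridden on a reusable environment only together with `reusable="off"`); it is not the Flyte SDK's own validation logic, and the policy object is just a placeholder.

```python
from typing import Any, Dict, Optional

RESTRICTED_WHEN_REUSABLE = {"resources", "env_vars", "secrets"}


def check_override(env_reusable: Optional[object], override_kwargs: Dict[str, Any]) -> None:
    """Raise if an override breaks the reusable-container rule described above."""
    if env_reusable is None:
        return  # no reuse policy on the environment: nothing to enforce here
    touched = RESTRICTED_WHEN_REUSABLE.intersection(override_kwargs)
    if touched and override_kwargs.get("reusable") != "off":
        raise ValueError(
            f"overriding {sorted(touched)} on a reusable environment "
            "requires reusable='off' in the same override() call"
        )


policy = object()  # placeholder standing in for a ReusePolicy instance
check_override(policy, {"retries": 3})                          # allowed
check_override(policy, {"resources": "2Gi", "reusable": "off"})  # allowed
```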
### `depends_on`
* Type: `List[Environment]`
* A list of **Configure tasks > `Environment`** objects that this `TaskEnvironment` depends on.
When deploying this `TaskEnvironment`, the system will ensure that any dependencies of the listed `Environment`s are also available.
This is useful when you have a set of task environments that depend on each other.
* Can only be set at the `TaskEnvironment` level, not at the `@env.task` decorator level or the `task.override()` invocation level.
* See **Configure tasks > Multiple environments**.
### `description`
* Type: `Optional[str]`
* A description of the task environment.
This can be used to provide additional context about the task environment, such as its purpose or usage.
* Can only be set at the `TaskEnvironment` level, not at the `@env.task` decorator level
or the `task.override()` invocation level.
### `plugin_config`
* Type: `Optional[Any]`
* Additional configuration for plugins that can be used with the task environment.
This can include settings for specific plugins that are used in the task environment.
* Can only be set at the `TaskEnvironment` level, not at the `@env.task` decorator level
or the `task.override()` invocation level.
### `report`
* Type: `bool`
* Whether to generate the HTML report for the task.
If set to `True`, the task will generate an HTML report that can be viewed in the Flyte UI.
* Can only be set at the `@env.task` decorator level,
not at the `TaskEnvironment` level or the `task.override()` invocation level.
* See **Build tasks > Reports**.
### `max_inline_io_bytes`
* Type: `int`
* Maximum allowed size (in bytes) for all inputs and outputs passed directly to the task
(e.g., primitives, strings, dictionaries).
Does not apply to **Build tasks > Files and directories** or **Build tasks > Dataclasses and structures** (since these are passed by reference).
* Can be set at the `@env.task` decorator level and overridden at the `task.override()` invocation level.
If not set, the default value is `MAX_INLINE_IO_BYTES` (which is 100 MiB).
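To put the numbers in perspective: the example earlier in this section sets `max_inline_io_bytes=100 * 1024`, which is 100 KiB, far below the documented 100 MiB default. The arithmetic is shown below; the variable names are illustrative only and are not SDK constants.

```python
KIB = 1024
MIB = 1024 * KIB

default_limit = 100 * MIB    # documented default: 100 MiB
example_limit = 100 * 1024   # value used in the configuration example: 100 KiB

print(default_limit)                    # 104857600
print(default_limit // example_limit)   # 1024
```

In other words, the example tightens the inline I/O limit by a factor of 1024 relative to the default.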
### `retries`
* Type: `Union[int, RetryStrategy]`
* The number of retries for the task, or a `RetryStrategy` object that defines the retry behavior.
If set to `0`, no retries will be attempted.
* Can be set at the `@env.task` decorator level and overridden at the `task.override()` invocation level.
* See **Configure tasks > Retries and timeouts**.
### `timeout`
* Type: `Union[timedelta, int]`
* The timeout for the task, either as a `timedelta` object or an integer representing seconds.
If set to `0`, no timeout will be applied.
* Can be set at the `@env.task` decorator level and overridden at the `task.override()` invocation level.
* See **Configure tasks > Retries and timeouts**.
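Because `timeout` accepts either a `timedelta` or an integer number of seconds, the two forms are interchangeable; for example, `timeout=60` in the configuration example means the same as `timedelta(minutes=1)`. The sketch below normalizes both forms. `normalize_timeout` is a hypothetical helper, not part of the SDK, and it treats only the integer `0` as "no timeout", as documented above.

```python
from datetime import timedelta
from typing import Optional, Union


def normalize_timeout(timeout: Union[timedelta, int]) -> Optional[timedelta]:
    """Map the accepted timeout forms onto a timedelta (None means no timeout)."""
    if isinstance(timeout, timedelta):
        return timeout
    if timeout < 0:
        raise ValueError("timeout must be non-negative")
    # An integer is a number of seconds; 0 disables the timeout.
    return None if timeout == 0 else timedelta(seconds=timeout)


print(normalize_timeout(60))
print(normalize_timeout(0))
```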
### `triggers`
* Type: `Tuple[Trigger, ...] | Trigger`
* A trigger or tuple of triggers that define when the task should be executed.
* Can only be set at the `@env.task` decorator level. It cannot be overridden.
* See **Configure tasks > Triggers**.
### `interruptible`
* Type: `bool`
* Specifies whether the task is interruptible.
If set to `True`, the task can be scheduled on a spot instance, otherwise it can only be scheduled on on-demand instances.
* Can be set at the `TaskEnvironment` level and overridden at the `@env.task` decorator level and at the `task.override()` invocation level.
### `queue`
* Type: `Optional[str]`
* Specifies the queue to which the task should be directed, where the queue is identified by its name.
If set to `None`, the default queue will be used.
Queues serve to point to specific partitions of your compute infrastructure (for example, a specific cluster in a multi-cluster setup).
They are configured as part of your Union/Flyte deployment.
* Can be set at the `TaskEnvironment` level and overridden at the `@env.task` decorator level
and at the `task.override()` invocation level.
### `docs`
* Type: `Optional[Documentation]`
* Documentation for the task, including usage examples and explanations of the task's behavior.
* Can only be set at the `@env.task` decorator level. It cannot be overridden.
## Subpages
- **Configure tasks > Container images**
- **Configure tasks > Resources**
- **Configure tasks > Secrets**
- **Configure tasks > Caching**
- **Configure tasks > Reusable containers**
- **Configure tasks > Pod templates**
- **Configure tasks > Multiple environments**
- **Configure tasks > Retries and timeouts**
- **Configure tasks > Triggers**
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-configuration/container-images ===
# Container images
The `image` parameter of the **Configure tasks > Container images > `TaskEnvironment`** is used to specify a container image.
Every task defined using that `TaskEnvironment` will run in a container based on that image.
If a `TaskEnvironment` does not specify an `image`, it will use the default Flyte image ([`ghcr.io/flyteorg/flyte:py{python-version}-v{flyte_version}`](https://github.com/orgs/flyteorg/packages/container/package/flyte)).
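The default image reference is a template with two placeholders. The sketch below shows one way it could expand, assuming (this is an assumption, not confirmed by this page) that the Python version is rendered as `major.minor` and the Flyte version is the SDK version string. Check the linked package page for the tags that actually exist before pinning one.

```python
import sys
from typing import Tuple


def default_image_uri(flyte_version: str,
                      python_version: Tuple[int, int] = sys.version_info[:2]) -> str:
    """Expand the default-image template (hypothetical; tag format is assumed)."""
    major, minor = python_version
    return f"ghcr.io/flyteorg/flyte:py{major}.{minor}-v{flyte_version}"


print(default_image_uri("2.0.0b31", (3, 13)))
```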
## Specifying your own image directly
You can directly reference an image by URL in the `image` parameter, like this:
```python
env = flyte.TaskEnvironment(
name="my_task_env",
image="docker.io/myorg/myimage:mytag"
)
```
This works well if you have a pre-built image available in a public registry like Docker Hub or in a private registry that your Union/Flyte instance can access.
## Specifying your own image with the `flyte.Image` object
You can also construct an image programmatically using the `flyte.Image` object.
The `flyte.Image` object provides a fluent interface for building container images with specific dependencies.
You start building your image with one of the `from_` methods:
* **Configure tasks > Container images > `Image.from_base()`**: Start from a pre-built image (note: the image must be accessible to the image builder).
* **Configure tasks > Container images > `Image.from_debian_base()`**: Start from a [Debian](https://www.debian.org/)-based base image that already includes Flyte.
* **Configure tasks > Container images > `Image.from_uv_script()`**: Start from an image built from a [uv script](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies) (slower, but easier).
You can then layer on additional components using the `with_` methods:
* **Configure tasks > Container images > `Image.with_apt_packages()`**: Add Debian packages to the image (e.g. apt-get ...).
* **Configure tasks > Container images > `Image.with_commands()`**: Add commands to run in the image (e.g. chmod a+x ... / curl ... / wget).
* **Configure tasks > Container images > `Image.with_dockerignore()`**: Specify a `.dockerignore` file that will be respected during the image build.
* **Configure tasks > Container images > `Image.with_env_vars()`**: Set environment variables in the image.
* **Configure tasks > Container images > `Image.with_pip_packages()`**: Add Python packages to the image (installed via `uv pip install ...`).
* **Configure tasks > Container images > `Image.with_requirements()`**: Specify a requirements.txt file (all packages will be installed).
* **Configure tasks > Container images > `Image.with_source_file()`**: Specify a source file to include in the image (the file will be copied).
* **Configure tasks > Container images > `Image.with_source_folder()`**: Specify a source folder to include in the image (entire folder will be copied).
* **Configure tasks > Container images > `Image.with_uv_project()`**: Use this with `pyproject.toml` or `uv.lock` based projects.
* **Configure tasks > Container images > `Image.with_poetry_project()`**: Create a new image with the specified `poetry.lock`
* **Configure tasks > Container images > `Image.with_workdir()`**: Specify the working directory for the image.
You can also specify an image in one shot (with no possibility of layering) with:
* **Configure tasks > Container images > `Image.from_dockerfile()`**: Build the final image from a single Dockerfile (useful if you already have a Dockerfile).
Additionally, the `Image` class provides:
* **Configure tasks > Container images > `Image.clone()`**: Clone an existing image. (Note: every `with_*` operation already returns a clone, because every image is immutable. `clone()` is useful when you need to create a new named image.)
* **Configure tasks > Container images > `Image.validate()`**: Validate the image configuration.
* **Configure tasks > Container images > `Image.with_local_v2()`**: Does not add a layer, instead it overrides any existing builder configuration and builds the image locally. See **Configure tasks > Container images > Image building** for more details.
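The note on `Image.clone()` describes an immutable, fluent design: each `with_*` call returns a new image rather than modifying the existing one. The sketch below illustrates that pattern with a hypothetical frozen dataclass, `ImageSpec`; it is not the implementation of `flyte.Image`, just a minimal model of the behavior described above.

```python
from dataclasses import dataclass, replace
from typing import Dict, Tuple


@dataclass(frozen=True)
class ImageSpec:
    """Hypothetical stand-in for an immutable image definition."""
    base: str
    name: str = "image"
    apt_packages: Tuple[str, ...] = ()
    pip_packages: Tuple[str, ...] = ()
    env_vars: Tuple[Tuple[str, str], ...] = ()

    def with_apt_packages(self, *packages: str) -> "ImageSpec":
        # replace() builds a new frozen instance; self is never mutated.
        return replace(self, apt_packages=self.apt_packages + packages)

    def with_pip_packages(self, *packages: str) -> "ImageSpec":
        return replace(self, pip_packages=self.pip_packages + packages)

    def with_env_vars(self, env: Dict[str, str]) -> "ImageSpec":
        return replace(self, env_vars=self.env_vars + tuple(sorted(env.items())))

    def clone(self, name: str) -> "ImageSpec":
        return replace(self, name=name)


base = ImageSpec(base="python:3.13-slim-bookworm")
numeric = base.with_apt_packages("libopenblas-dev").with_pip_packages("numpy")
named = numeric.clone(name="my-image")
print(base)
print(named)
```

Because `base` is never modified, it can safely be shared as the starting point for several differently layered images across task environments.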
Here are some examples of the most common patterns for building images with `flyte.Image`.
## Example: Defining a custom image with `Image.from_debian_base`
The `Image.from_debian_base()` provides the default Flyte image as the base.
This image is itself based on the official Python Docker image (specifically `python:{version}-slim-bookworm`) with the addition of the Flyte SDK pre-installed.
Starting there, you can layer additional features onto your image.
For example:
```python
import flyte
import numpy as np

# Define the task environment
env = flyte.TaskEnvironment(
    name="my_env",
    image=(
        flyte.Image.from_debian_base(
            name="my-image",
            python_version=(3, 13),
            # registry="registry.example.com/my-org"  # Only needed for local builds
        )
        .with_apt_packages("libopenblas-dev")
        .with_pip_packages("numpy")
        .with_env_vars({"OMP_NUM_THREADS": "4"})
    ),
)


@env.task
def main(x_list: list[int]) -> float:
    arr = np.array(x_list)
    return float(np.mean(arr))


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, x_list=list(range(10)))
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/container-images/from_debian_base.py)
> [!NOTE]
> The `registry` parameter is only needed if you are building the image locally. It is not required when using the Union backend `ImageBuilder`.
> See **Configure tasks > Container images > Image building** for more details.
> [!NOTE]
> Images built with `flyte.Image.from_debian_base()` do not include CA certificates by default, which can cause TLS
> validation errors and block access to HTTPS-based storage such as Amazon S3. Libraries like Polars (e.g., `polars.scan_parquet()`) are particularly affected.
> **Solution:** Add `"ca-certificates"` using `.with_apt_packages()` in your image definition.
## Example: Defining an image based on uv script metadata
Another common technique for defining an image is to use [`uv` inline script metadata](https://docs.astral.sh/uv/guides/scripts/#declaring-script-dependencies) to specify your dependencies right in your Python file and then use the `flyte.Image.from_uv_script()` method to create a `flyte.Image` object.
The `from_uv_script` method starts with the default Flyte image and adds the dependencies specified in the `uv` metadata.
For example:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
#    "numpy"
# ]
# main = "main"
# params = "x_list=[1,2,3,4,5,6,7,8,9,10]"
# ///
import flyte
import numpy as np

env = flyte.TaskEnvironment(
    name="my_env",
    image=flyte.Image.from_uv_script(
        __file__,
        name="my-image",
        # registry="registry.example.com/my-org"  # Only needed for local builds
    ),
)


@env.task
def main(x_list: list[int]) -> float:
    arr = np.array(x_list)
    return float(np.mean(arr))


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, x_list=list(range(10)))
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/container-images/from_uv_script.py)
The advantage of this approach is that the dependencies used when running a script locally and when running it on the Flyte/Union backend are always the same (as long as you use `uv` to run your scripts locally).
This means you can develop and test your scripts in a consistent environment, reducing the chances of encountering issues when deploying to the backend.
The example above calls `flyte.init_from_config()` followed by `flyte.run()`, so it runs remotely on your Union/Flyte instance.
To run it locally instead, chain `flyte.with_runcontext(mode="local")` with `run()`, as described in **Getting started > Running**.
> [!NOTE]
> When using `uv` metadata in this way, be sure to include the `flyte` package in your `uv` script dependencies.
> This will ensure that `flyte` is installed when running the script locally using `uv run`.
> When running on the Flyte/Union backend, the `flyte` package from the uv script dependencies will overwrite the one included automatically from the default Flyte image.
## Image building
There are two ways that the image can be built:
* If you are running a Flyte OSS instance then the image will be built locally on your machine and pushed to the container registry you specified in the `Image` definition.
* If you are running a Union instance, the image can be built locally, as with Flyte OSS, or using the Union `ImageBuilder`, which runs remotely on Union's infrastructure.
### Configuring the `builder`
In **Getting started > Local setup**, we discussed the `image.builder` property in the `config.yaml`.
For Flyte OSS instances, this property must be set to `local`.
For Union instances, this property can be set to `remote` to use the Union `ImageBuilder`, or `local` to build the image locally on your machine.
### Local image building
When `image.builder` in the `config.yaml` is set to `local`, `flyte.run()` does the following:
* Builds the Docker image using your local Docker installation, installing the dependencies specified in the `uv` inline script metadata.
* Pushes the image to the container registry you specified.
* Deploys your code to the backend.
* Kicks off the execution of your workflow.
* Before the task that uses your custom image is executed, the backend pulls the image from the registry to set up the container.
> [!NOTE]
> The example above includes a commented-out `registry="registry.example.com/my-org"` argument, which is only needed for local builds.
>
> Be sure to change `registry.example.com/my-org` to the URL of your actual container registry.
You must ensure that:
* Docker is running on your local machine.
* You have successfully run `docker login` to that registry from your local machine (for example, for the GitHub container registry: `echo $GITHUB_TOKEN | docker login ghcr.io -u USERNAME --password-stdin`).
* Your Union/Flyte installation has read access to that registry.
> [!NOTE]
> If you are using the GitHub container registry (`ghcr.io`)
> note that images pushed there are private by default.
> You may need to go to the image URI, click **Package Settings**, and change the visibility to public in order to access the image.
>
> Other registries (such as Docker Hub) require that you pre-create the image repository before pushing the image.
> In that case you can set it to public when you create it.
>
> Public images are on the public internet and should only be used for testing purposes.
> Do not place proprietary code in public images.
### Remote `ImageBuilder`
`ImageBuilder` is a service provided by Union that builds container images on Union's infrastructure and provides an internal container registry for storing the built images.
When `image.builder` in the `config.yaml` is set to `remote` (and you are running Union.ai), `flyte.run()` does the following:
* Builds the Docker image on your Union instance with `ImageBuilder`.
* Pushes the image to a registry:
  * If you did not specify a `registry` in the `Image` definition, it pushes to the internal registry in your Union instance.
  * If you did specify a `registry`, it pushes to that registry. Be sure to also set the `registry_secret` parameter in the `Image` definition to enable `ImageBuilder` to authenticate to that registry (see **Configure tasks > Container images > Image building > Remote `ImageBuilder` > ImageBuilder with external registries**).
* Deploys your code to the backend.
* Kicks off the execution of your workflow.
* Before the task that uses your custom image is executed, the backend pulls the image from the registry to set up the container.
No Docker setup or other local configuration is required on your part.
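The registry choice described above can be summarized as a small decision function. This is a hypothetical sketch of the rules stated in this section, not `ImageBuilder`'s actual code: `INTERNAL_REGISTRY` is a placeholder, and treating a missing `registry_secret` as an error is a choice made for the sketch.

```python
from dataclasses import dataclass
from typing import Optional

# Placeholder for the internal registry of a Union instance (hypothetical value).
INTERNAL_REGISTRY = "<union-internal-registry>"


@dataclass
class PushTarget:
    registry: str
    secret: Optional[str]  # Name of the image_pull secret used to authenticate, if any


def resolve_push_target(registry: Optional[str], registry_secret: Optional[str]) -> PushTarget:
    """No registry -> internal registry; external registry -> requires registry_secret."""
    if registry is None:
        # No registry in the Image definition: push to the instance's internal registry.
        return PushTarget(INTERNAL_REGISTRY, None)
    if not registry_secret:
        # ImageBuilder needs credentials to push to an external registry.
        raise ValueError(f"registry_secret is required when pushing to {registry!r}")
    return PushTarget(registry, registry_secret)
```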
#### ImageBuilder with external registries
If you want to push the images built by `ImageBuilder` to an external registry, set the `registry` parameter in the `Image` object.
You also need to set the `registry_secret` parameter to provide the credentials needed to push images to, and pull images from, the private registry.
For example:
```python
import flyte

# Add registry credentials so the Union remote builder can pull the base image
# and push the resulting image to your private registry.
image = flyte.Image.from_debian_base(
    name="my-image",
    base_image="registry.example.com/my-org/my-private-image:latest",
    registry="registry.example.com/my-org",
    registry_secret="my-secret",
)

# Reference the same secret in the TaskEnvironment so Flyte can pull the image at runtime.
env = flyte.TaskEnvironment(
    name="my_task_env",
    image=image,
    secrets="my-secret",
)
```
The value of the `registry_secret` parameter must be the name of a Flyte secret of type `image_pull` that contains the credentials needed to access the private registry. It must match the name specified in the `secrets` parameter of the `TaskEnvironment` so that Flyte can use it to pull the image at runtime.
To create an `image_pull` secret for the remote builder and the task environment, run the following command:
```shell
$ flyte create secret --type image_pull my-secret --from-file ~/.docker/config.json
```
The format of this secret matches the standard Kubernetes [image pull secret](https://kubernetes.io/docs/tasks/configure-pod-container/pull-image-private-registry/#log-in-to-docker-hub), and should look like this:
```json
{
  "auths": {
    "registry.example.com": {
      "auth": "base64-encoded-auth"
    }
  }
}
```
> [!NOTE]
> The `auth` field contains the base64-encoded credentials for your registry (username and password or token).
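As a sketch of that encoding (the helper names are hypothetical; the `username:password` base64 format is the standard Docker one), you can build a minimal `config.json`-style document for a single registry like this:

```python
import base64
import json


def docker_auth_entry(username: str, token: str) -> str:
    """Base64-encode "username:token", the format Docker uses for the `auth` field."""
    return base64.b64encode(f"{username}:{token}".encode("utf-8")).decode("ascii")


def docker_config(registry: str, username: str, token: str) -> str:
    """Build a minimal config.json-style document with a single `auths` entry."""
    config = {"auths": {registry: {"auth": docker_auth_entry(username, token)}}}
    return json.dumps(config, indent=2)
```

You could write the output to a file and pass that file to `flyte create secret --type image_pull ... --from-file`. Treat the file as sensitive: base64 is an encoding, not encryption.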
### Install private PyPI packages
To install Python packages from a private source (for example, a private GitHub repository), you can mount a secret into the image build layer.
This allows your build to authenticate securely during dependency installation.
For example:
```python
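from flyte import Image, Secret

# Private Git dependency; the GITHUB_PAT secret is mounted for this build step so $GITHUB_PAT can be resolved
private_package = "git+https://$GITHUB_PAT@github.com/pingsutw/flytex.git@2e20a2acebfc3877d84af643fdd768edea41d533"
image = (
    Image.from_debian_base()
    .with_apt_packages("git")  # git is needed to install from a Git URL
    .with_pip_packages(private_package, pre=True, secret_mounts=Secret("GITHUB_PAT"))
)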
```
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-configuration/resources ===
# Resources
Task resources specify the computational limits and requests (CPU, memory, GPU, storage) that will be allocated to each task's container during execution.
To specify resource requirements for your task, instantiate a `Resources` object with the desired parameters and assign it to either
the `resources` parameter of the `TaskEnvironment` or the `resources` parameter of the `override` function (for invocation overrides).
Every task defined using that `TaskEnvironment` will run with the specified resources.
If a specific task has its own `resources` defined in the decorator, it will override the environment's resources for that task only.
If neither `TaskEnvironment` nor the task decorator specifies `resources`, the default resource allocation will be used.
## Resources dataclass
The `Resources` dataclass provides the following initialization parameters:
```python
resources = flyte.Resources(
cpu: Union[int, float, str, Tuple[Union[int, float, str], Union[int, float, str]], None] = None,
memory: Union[str, Tuple[str, str], None] = None,
gpu: Union[str, int, flyte.Device, None] = None,
disk: Union[str, None] = None,
shm: Union[str, Literal["auto"], None] = None
)
```
Each parameter is optional and allows you to specify different types of resources:
- **`cpu`**: CPU allocation - can be a number, string, or tuple for request/limit ranges (e.g., `2` or `(2, 4)`).
- **`memory`**: Memory allocation - string with units (e.g., `"4Gi"`) or tuple for ranges.
- **`gpu`**: GPU allocation - accelerator string (e.g., `"A100:2"`), count, or `Device` (see **Configure tasks > Resources > GPU resources**, **Configure tasks > Resources > TPU resources**, or **Configure tasks > Resources > Custom device specifications**).
- **`disk`**: Ephemeral storage - string with units (e.g., `"10Gi"`).
- **`shm`**: Shared memory - string with units or `"auto"` for automatic sizing (e.g., `"8Gi"` or `"auto"`).
## Examples
### Usage in TaskEnvironment
Here's a complete example of defining a TaskEnvironment with resource specifications for a machine learning training workload:
```python
import flyte

# Define a TaskEnvironment for ML training tasks
env = flyte.TaskEnvironment(
    name="ml-training",
    resources=flyte.Resources(
        cpu=("2", "4"),  # Request 2 cores, allow up to 4 cores for scaling
        memory=("2Gi", "12Gi"),  # Request 2 GiB, allow up to 12 GiB for large datasets
        disk="50Gi",  # 50 GiB ephemeral storage for checkpoints
        shm="8Gi",  # 8 GiB shared memory for efficient data loading
    ),
)


# Use the environment for tasks
@env.task
async def train_model(dataset_path: str) -> str:
    # This task will run with flexible resource allocation
    return "model trained"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/resources/resources.py)
### Usage in a task-specific override
```python
# Demonstrate resource override at task invocation level
@env.task
async def heavy_training_task() -> str:
    return "heavy model trained with overridden resources"


@env.task
async def main():
    # Task using environment-level resources
    result = await train_model("data.csv")
    print(result)
    # Task with overridden resources at invocation time
    result = await heavy_training_task.override(
        resources=flyte.Resources(
            cpu="4",
            memory="24Gi",
            disk="100Gi",
            shm="16Gi",
        )
    )()
    print(result)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/resources/resources.py)
## Resource types
### CPU resources
CPU can be specified in several formats:
```python
# String formats (Kubernetes-style)
flyte.Resources(cpu="500m") # 500 milliCPU (0.5 cores)
flyte.Resources(cpu="2") # 2 CPU cores
flyte.Resources(cpu="1.5") # 1.5 CPU cores
# Numeric formats
flyte.Resources(cpu=1) # 1 CPU core
flyte.Resources(cpu=0.5) # 0.5 CPU cores
# Request and limit ranges
flyte.Resources(cpu=("1", "2")) # Request 1 core, limit to 2 cores
flyte.Resources(cpu=(1, 4)) # Request 1 core, limit to 4 cores
```
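To make these units concrete, the hypothetical helper below (not part of the Flyte SDK) converts the CPU formats shown above into a number of cores. It follows the Kubernetes convention that the `m` suffix means millicores (1000m = 1 core). A request/limit tuple can be converted element by element.

```python
from decimal import Decimal
from typing import Union


def cpu_to_cores(value: Union[int, float, str]) -> float:
    """Convert a Kubernetes-style CPU quantity ("500m", "2", 1.5) to cores."""
    if isinstance(value, (int, float)):
        return float(value)
    text = value.strip()
    if text.endswith("m"):
        # The "m" suffix means millicores: 1000m == 1 core.
        return float(Decimal(text[:-1]) / 1000)
    return float(Decimal(text))
```

For example, `[cpu_to_cores(v) for v in ("1", "2")]` gives the request and limit of `cpu=("1", "2")` in cores.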
### Memory resources
Memory specifications follow Kubernetes conventions:
```python
# Standard memory units
flyte.Resources(memory="512Mi") # 512 MiB
flyte.Resources(memory="1Gi") # 1 GiB
flyte.Resources(memory="2Gi") # 2 GiB
flyte.Resources(memory="500M") # 500 MB (decimal)
flyte.Resources(memory="1G") # 1 GB (decimal)
# Request and limit ranges
flyte.Resources(memory=("1Gi", "4Gi")) # Request 1 GiB, limit to 4 GiB
```
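The binary and decimal suffixes differ by more than a rounding error: `"1Gi"` is 1,073,741,824 bytes while `"1G"` is 1,000,000,000 bytes, roughly 7% less. The hypothetical helper below is a simplified sketch that handles only whole-number quantities and the suffixes listed. It converts these strings to bytes so you can compare them.

```python
# Binary (power-of-two) and decimal (power-of-ten) suffixes used by Kubernetes quantities.
_BINARY = {"Ki": 2**10, "Mi": 2**20, "Gi": 2**30, "Ti": 2**40}
_DECIMAL = {"k": 10**3, "M": 10**6, "G": 10**9, "T": 10**12}


def memory_to_bytes(value: str) -> int:
    """Convert a memory quantity such as "512Mi" or "500M" to bytes (whole numbers only)."""
    text = value.strip()
    # Check the two-letter binary suffixes before the one-letter decimal ones.
    for suffixes in (_BINARY, _DECIMAL):
        for suffix, factor in suffixes.items():
            if text.endswith(suffix):
                return int(text[: -len(suffix)]) * factor
    return int(text)  # No suffix: plain bytes
```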
### GPU resources
Flyte supports various GPU types and configurations:
#### Simple GPU allocation
```python
# Basic GPU count
flyte.Resources(gpu=1) # 1 GPU (any available type)
flyte.Resources(gpu=4) # 4 GPUs
# Specific GPU types with quantity
flyte.Resources(gpu="T4:1") # 1 NVIDIA T4 GPU
flyte.Resources(gpu="A100:2") # 2 NVIDIA A100 GPUs
flyte.Resources(gpu="H100:8") # 8 NVIDIA H100 GPUs
```
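The accelerator string packs a device name and a count into one value. As an illustration only, assuming the `DEVICE:COUNT` shape shown above, the string can be split like this. This is a hypothetical helper; Flyte's own parsing may differ.

```python
def parse_accelerator(spec: str) -> tuple[str, int]:
    """Split an accelerator string like "A100:2" into (device, count)."""
    device, sep, count = spec.rpartition(":")
    if not sep or not device:
        raise ValueError(f"expected 'DEVICE:COUNT', got {spec!r}")
    return device, int(count)
```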
#### Advanced GPU configuration
You can also use the `GPU` helper function for more detailed configurations:
```python
# Using the GPU helper function
gpu_config = flyte.GPU(device="A100", quantity=2)
flyte.Resources(gpu=gpu_config)
# GPU with memory partitioning (A100 only)
partitioned_gpu = flyte.GPU(
device="A100",
quantity=1,
partition="1g.5gb" # 1/7th of A100 with 5GB memory
)
flyte.Resources(gpu=partitioned_gpu)
# A100 80GB with partitioning
large_partition = flyte.GPU(
device="A100 80G",
quantity=1,
partition="7g.80gb" # Full A100 80GB
)
flyte.Resources(gpu=large_partition)
```
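MIG partition names follow the pattern `<compute slices>g.<memory>gb`, and an A100 is divided into at most 7 compute slices. That is why `"1g.5gb"` is about 1/7 of the GPU and `"7g.80gb"` is the whole A100 80GB. The hypothetical helper below decodes a profile name under those assumptions. It does not check whether a profile is valid for a particular GPU model.

```python
import re

# MIG profile names have the form "<compute slices>g.<memory>gb", e.g. "1g.5gb".
_MIG_PROFILE = re.compile(r"^(?P<slices>\d+)g\.(?P<mem>\d+)gb$")


def parse_mig_profile(profile: str, total_slices: int = 7) -> dict:
    """Describe a MIG partition string; A100 GPUs expose 7 compute slices in total."""
    match = _MIG_PROFILE.match(profile)
    if match is None:
        raise ValueError(f"not a MIG profile: {profile!r}")
    slices = int(match["slices"])
    return {
        "compute_slices": slices,
        "compute_fraction": slices / total_slices,
        "memory_gb": int(match["mem"]),
    }
```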
#### Supported GPU types
- **T4**: Entry-level training and inference
- **L4**: Optimized for AI inference
- **L40s**: High-performance compute
- **A100**: High-end training and inference (40GB)
- **A100 80G**: High-end training with more memory (80GB)
- **H100**: Latest generation, highest performance
### Custom device specifications
You can also define custom devices if your infrastructure supports them:
```python
# Custom device configuration
custom_device = flyte.Device(
device="custom_accelerator",
quantity=2,
partition="large"
)
resources = flyte.Resources(gpu=custom_device)
```
### TPU resources
For Google Cloud TPU workloads, you can specify TPU resources using the `TPU` helper function:
```python
# TPU v5p configuration
tpu_config = flyte.TPU(device="V5P", partition="2x2x1")
flyte.Resources(gpu=tpu_config) # Note: TPUs use the gpu parameter
# TPU v6e configuration
tpu_v6e = flyte.TPU(device="V6E", partition="4x4")
flyte.Resources(gpu=tpu_v6e)
```
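The TPU `partition` value is a slice topology. Its dimensions multiply to the number of TPU chips in the slice, so `"2x2x1"` is 4 chips and `"4x4"` is 16 chips. Here is a hypothetical helper that computes this:

```python
import math


def tpu_chip_count(topology: str) -> int:
    """Number of TPU chips in a topology string such as "2x2x1" or "4x4"."""
    dims = [int(d) for d in topology.lower().split("x")]
    return math.prod(dims)
```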
### Storage resources
Flyte provides two types of storage resources for tasks: ephemeral disk storage and shared memory.
These resources are essential for tasks that need temporary storage for processing data, caching intermediate results, or sharing data between processes.
#### Disk storage
Ephemeral disk storage provides temporary space for your tasks to store intermediate files, downloaded datasets, model checkpoints, and other temporary data. This storage is automatically cleaned up when the task completes.
```python
flyte.Resources(disk="10Gi") # 10 GiB ephemeral storage
flyte.Resources(disk="100Gi") # 100 GiB ephemeral storage
flyte.Resources(disk="1Ti") # 1 TiB for large-scale data processing
# Common use cases
flyte.Resources(disk="50Gi") # ML model training with checkpoints
flyte.Resources(disk="200Gi") # Large dataset preprocessing
flyte.Resources(disk="500Gi") # Video/image processing workflows
```
#### Shared memory
Shared memory (`/dev/shm`) is a high-performance, RAM-based storage area that can be shared between processes within the same container. It's particularly useful for machine learning workflows that need fast data loading and inter-process communication.
```python
flyte.Resources(shm="1Gi") # 1 GiB shared memory (/dev/shm)
flyte.Resources(shm="auto") # Auto-sized shared memory
flyte.Resources(shm="16Gi") # Large shared memory for distributed training
```
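To see what `/dev/shm` is used for, here is a minimal sketch using Python's standard `multiprocessing.shared_memory` module, which on Linux creates its segments under `/dev/shm`. For brevity, it attaches the second handle in the same process. In practice, another process in the container, such as a data-loading worker, would attach by name. The `shm` setting caps how much of this memory the container can use.

```python
from multiprocessing import shared_memory


def shm_roundtrip(payload: bytes) -> bytes:
    """Write bytes into a named shared-memory block and read them back through a second handle."""
    writer = shared_memory.SharedMemory(create=True, size=len(payload))
    try:
        writer.buf[: len(payload)] = payload
        # Another process would attach to the same block by name exactly like this.
        reader = shared_memory.SharedMemory(name=writer.name)
        try:
            return bytes(reader.buf[: len(payload)])
        finally:
            reader.close()
    finally:
        writer.close()
        writer.unlink()  # Free the segment (on Linux, removes it from /dev/shm)
```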
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-configuration/secrets ===
# Secrets
Flyte secrets enable you to securely store and manage sensitive information, such as API keys, passwords, and other credentials.
Secrets reside in a secret store on the data plane of your Union/Flyte backend.
You can create, list, and delete secrets in the store using the Flyte CLI or SDK.
Secrets in the store can be accessed and used within your workflow tasks, without exposing any cleartext values in your code.
## Creating a literal string secret
You can create a secret using the **Flyte CLI > flyte > flyte create > flyte create secret** command like this:
```shell
flyte create secret MY_SECRET_KEY my_secret_value
```
This will create a secret called `MY_SECRET_KEY` with the value `my_secret_value`.
This secret will be scoped to your entire organization.
It will be available across all projects and domains in your organization.
See the **Configure tasks > Secrets > Scoping secrets** section below for more details.
See **Configure tasks > Secrets > Using a literal string secret** for how to access the secret in your task code.
## Creating a file secret
You can also create a secret by specifying a local file:
```shell
flyte create secret MY_SECRET_KEY --from-file /local/path/to/my_secret_file
```
In this case, your task code must read the secret as a file (see **Configure tasks > Secrets > Using a file secret**).
## Scoping secrets
When you create a secret without specifying a project or domain, as we did above, the secret is scoped to the organization level.
This means that the secret will be available across all projects and domains in the organization.
You can optionally specify either or both of the `--project` and `--domain` flags to restrict the scope of the secret to:
* A specific project (across all domains)
* A specific domain (across all projects)
* A specific project and a specific domain.
For example, to create a secret that is only available in `my_project/development`, run the following command:
```shell
flyte create secret --project my_project --domain development MY_SECRET_KEY my_secret_value
```
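The scoping rules above can be modeled as a simple predicate, together with the note later on this page that a `TaskEnvironment` can only use secrets whose scope includes its project and domain. This is a hypothetical sketch of the rules as described, where an unset project or domain means "all".

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SecretScope:
    project: Optional[str] = None  # None means "all projects"
    domain: Optional[str] = None  # None means "all domains"


def is_visible(scope: SecretScope, project: str, domain: str) -> bool:
    """True if a secret with this scope can be used by a task deployed to project/domain."""
    project_ok = scope.project is None or scope.project == project
    domain_ok = scope.domain is None or scope.domain == domain
    return project_ok and domain_ok
```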
## Listing secrets
You can list existing secrets with the **Flyte CLI > flyte > flyte get > flyte get secret** command.
For example, the following command will list all secrets in the organization:
```shell
$ flyte get secret
```
Specifying either or both of the `--project` and `--domain` flags will list the secrets that are **only** available in that project and/or domain.
For example, to list the secrets that are only available in `my_project` and domain `development`, you would run:
```shell
flyte get secret --project my_project --domain development
```
## Deleting secrets
To delete a secret, use the **Flyte CLI > flyte > flyte delete > flyte delete secret** command:
```shell
flyte delete secret MY_SECRET_KEY
```
## Using a literal string secret
To use a literal string secret, specify it in the `TaskEnvironment` along with the name of the environment variable into which it will be injected.
You can then access it using `os.getenv()` in your task code.
For example:
```
import os

import flyte

env_1 = flyte.TaskEnvironment(
    name="env_1",
    secrets=[
        flyte.Secret(key="my_secret", as_env_var="MY_SECRET_ENV_VAR"),
    ],
)


@env_1.task
def task_1():
    my_secret_value = os.getenv("MY_SECRET_ENV_VAR")
    print(f"My secret value is: {my_secret_value}")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/secrets/secrets.py)
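The example above prints the secret value for demonstration purposes only. In real tasks you usually want to fail fast when the variable was not injected and avoid writing the cleartext value to logs. The helper names below are hypothetical; this is plain Python you could call from inside a task.

```python
import os


def require_env(name: str) -> str:
    """Return an environment variable's value, failing clearly if the secret was not injected."""
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(
            f"environment variable {name} is not set; is the secret attached to the TaskEnvironment?"
        )
    return value


def mask(value: str, visible: int = 2) -> str:
    """Show only the last few characters of a secret, e.g. in log messages."""
    if len(value) <= visible:
        return "*" * len(value)
    return "*" * (len(value) - visible) + value[-visible:]
```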
## Using a file secret
To use a file secret, specify it in the `TaskEnvironment` along with the `mount="/etc/flyte/secrets"` argument (with that precise value).
The file will be mounted at `/etc/flyte/secrets/`.
For example:
```
import flyte

env_2 = flyte.TaskEnvironment(
    name="env_2",
    secrets=[
        flyte.Secret(key="my_secret", mount="/etc/flyte/secrets"),
    ],
)


@env_2.task
def task_2():
    with open("/etc/flyte/secrets/my_secret", "r") as f:
        my_secret_file_content = f.read()
    print(f"My secret file content is: {my_secret_file_content}")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/secrets/secrets.py)
> [!NOTE]
> Currently, to access a file secret you must specify a `mount` parameter value of `"/etc/flyte/secrets"`.
> This fixed path is the directory in which the secret file will be placed.
> The name of the secret file will be equal to the key of the secret.
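Because the file name equals the secret key, reading a file secret comes down to joining the key onto the fixed mount directory. The helper below is a hypothetical sketch. Its `secrets_dir` parameter exists only so it can be exercised outside a task. It reads bytes so binary secrets (such as certificates) are preserved. Decode the result to `str` if you know the secret is text.

```python
from pathlib import Path

# Fixed mount directory for file secrets, as described in the note above.
SECRETS_DIR = Path("/etc/flyte/secrets")


def read_file_secret(key: str, secrets_dir: Path = SECRETS_DIR) -> bytes:
    """Read the file secret named `key` from the mount directory."""
    path = secrets_dir / key
    if not path.is_file():
        raise FileNotFoundError(f"secret file {path} not found; check the key and the mount path")
    return path.read_bytes()
```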
> [!NOTE]
> A `TaskEnvironment` can only access a secret if the scope of the secret includes the project and domain where the `TaskEnvironment` is deployed.
> [!WARNING]
> Do not return secret values from tasks, as this will expose secrets to the control plane.
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-configuration/caching ===
# Caching
Flyte 2 provides intelligent **task output caching** that automatically avoids redundant computation by reusing previously computed task results.
> [!NOTE]
> Caching works at the task level and caches complete task outputs.
> For function-level checkpointing and resumption *within tasks*, see **Build tasks > Traces**.
## Overview
By default, caching is disabled.
If caching is enabled for a task, then Flyte determines a **cache key** for the task.
The key is composed of the following:
* Final inputs: The set of inputs after removing any listed in `ignored_inputs`.
* Task name: The fully-qualified name of the task.
* Interface hash: A hash of the task's input and output types.
* Cache version: The cache version string.
If the cache behavior is set to `"auto"`, the cache version is automatically generated using a hash of the task's source code (or according to the custom policy if one is specified).
If the cache behavior is set to `"override"`, the cache version can be specified explicitly using the `version_override` parameter.
When the task runs, Flyte checks if a cache entry exists for the key.
If found, the cached result is returned immediately instead of re-executing the task.
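The following sketch shows why changing any of the key components produces a cache miss, while changing an ignored input does not. It is not Flyte's actual hashing scheme. The JSON encoding and SHA-256 digest are assumptions chosen for illustration, and the inputs must be JSON-serializable.

```python
import hashlib
import json
from typing import Any, Iterable


def sketch_cache_key(
    task_name: str,
    interface_hash: str,
    cache_version: str,
    inputs: dict[str, Any],
    ignored_inputs: Iterable[str] = (),
) -> str:
    """Combine the cache-key components into one digest (illustrative only)."""
    ignored = set(ignored_inputs)
    final_inputs = {k: v for k, v in inputs.items() if k not in ignored}
    payload = json.dumps(
        {
            "inputs": final_inputs,
            "task": task_name,
            "interface": interface_hash,
            "version": cache_version,
        },
        sort_keys=True,  # Stable ordering so equal inputs give equal keys
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
```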
## Basic caching usage
Flyte 2 supports three main cache behaviors:
### `"auto"` - Automatic versioning
```
@env.task(cache=flyte.Cache(behavior="auto"))
async def auto_versioned_task(data: str) -> str:
    return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
With `behavior="auto"`, the cache version is automatically generated based on the function's source code.
If you change the function implementation, the cache is automatically invalidated.
- **When to use**: Development and most production scenarios.
- **Cache invalidation**: Automatic when function code changes.
- **Benefits**: Zero-maintenance caching that "just works".
You can also use the direct string shorthand:
```
@env.task(cache="auto")
async def auto_versioned_task_2(data: str) -> str:
    return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
### `"override"`
With `behavior="override"`, you specify the cache version explicitly in the `version_override` parameter.
Because the version is fixed in your code, you invalidate the cache by changing it manually.
```
@env.task(cache=flyte.Cache(behavior="override", version_override="v1.2"))
async def manually_versioned_task(data: str) -> str:
    return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
- **When to use**: When you need explicit control over cache invalidation.
- **Cache invalidation**: Manual, by changing `version_override`.
- **Benefits**: Stable caching across code changes that don't affect logic.
### `"disable"` - No caching
To explicitly disable caching, use the `"disable"` behavior.
**This is the default behavior.**
```
@env.task(cache=flyte.Cache(behavior="disable"))
async def always_fresh_task(data: str) -> str:
    return get_current_timestamp() + await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
- **When to use**: Non-deterministic functions, side effects, or always-fresh data.
- **Cache invalidation**: N/A - never cached.
- **Benefits**: Ensures execution every time.
You can also use the direct string shorthand:
```
@env.task(cache="disable")
async def always_fresh_task_2(data: str) -> str:
    return get_current_timestamp() + await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
## Advanced caching configuration
### Ignoring specific inputs
Sometimes you want to cache based on some inputs but not others:
```
@env.task(cache=flyte.Cache(behavior="auto", ignored_inputs=("debug_flag",)))
async def selective_caching(data: str, debug_flag: bool) -> str:
    if debug_flag:
        print(f"Debug: transforming {data}")
    return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
**This is useful for**:
- Debug flags that don't affect computation
- Logging levels or output formats
- Metadata that doesn't impact results
### Cache serialization
Cache serialization ensures that only one instance of a task runs at a time for identical inputs:
```
@env.task(cache=flyte.Cache(behavior="auto", serialize=True))
async def expensive_model_training(data: str) -> str:
    return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
**When to use serialization**:
- Very expensive computations (model training, large data processing)
- Shared resources that shouldn't be accessed concurrently
- Operations where multiple parallel executions provide no benefit
**How it works**:
1. First execution acquires a reservation and runs normally.
2. Concurrent executions with identical inputs wait for the first to complete.
3. Once complete, all waiting executions receive the cached result.
4. If the running execution fails, another waiting execution takes over.
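The steps above can be modeled in a single process with one `asyncio.Lock` per cache key, where holding the lock plays the role of the reservation. This is a hypothetical sketch of the idea, not how Flyte coordinates executions across machines. Results are cached only on success, so a failure lets the next waiter run the computation.

```python
import asyncio
from typing import Awaitable, Callable


class SerializedCache:
    """Run at most one computation per key at a time and share its result (illustrative sketch)."""

    def __init__(self) -> None:
        self._results: dict[str, object] = {}
        self._locks: dict[str, asyncio.Lock] = {}

    async def get_or_run(self, key: str, compute: Callable[[], Awaitable[object]]) -> object:
        lock = self._locks.setdefault(key, asyncio.Lock())
        async with lock:  # The first caller holds the "reservation"; others wait here.
            if key in self._results:
                return self._results[key]  # An earlier holder succeeded: reuse its result.
            # If compute() raises, the lock is released and the next waiter retries.
            result = await compute()
            self._results[key] = result
            return result
```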
### Salt for cache key variation
Use `salt` to vary cache keys without changing function logic:
```
@env.task(cache=flyte.Cache(behavior="auto", salt="experiment_2024_q4"))
async def experimental_analysis(data: str) -> str:
    return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
**`salt` is useful for**:
- A/B testing with identical code.
- Temporary cache namespaces for experiments.
- Environment-specific cache isolation.
## Cache policies
For `behavior="auto"`, Flyte uses cache policies to generate version hashes.
### Function body policy (default)
The default `FunctionBodyPolicy` generates cache versions from the function's source code:
```
from flyte._cache import FunctionBodyPolicy
@env.task(cache=flyte.Cache(
    behavior="auto",
    policies=[FunctionBodyPolicy()],  # This is the default. Does not actually need to be specified.
))
async def code_sensitive_task(data: str) -> str:
    return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
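A source-based policy hashes the function's text, so any edit yields a new version. A `salt` is mixed into the hash to separate otherwise identical code into distinct cache namespaces. The sketch below is hypothetical, and `FunctionBodyPolicy`'s real hashing may differ (for example, in how it normalizes source). In a real program you would obtain the text with `inspect.getsource(func)`.

```python
import hashlib


def body_version(source: str, salt: str = "") -> str:
    """Derive a short cache version from a function's source text plus an optional salt."""
    return hashlib.sha256((salt + "\n" + source).encode("utf-8")).hexdigest()[:16]
```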
### Custom cache policies
You can implement custom cache policies by following the `CachePolicy` protocol:
```
from flyte._cache import CachePolicy
class DatasetVersionPolicy(CachePolicy):
    def get_version(self, salt: str, params) -> str:
        # Generate version based on custom logic
        dataset_version = get_dataset_version()  # Your own helper (not shown)
        return f"{salt}_{dataset_version}"


@env.task(cache=flyte.Cache(behavior="auto", policies=[DatasetVersionPolicy()]))
async def dataset_dependent_task(data: str) -> str:
    # Cache invalidated when dataset version changes
    return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
## Caching configuration at different levels
You can configure caching at three levels: `TaskEnvironment` definition, `@env.task` decorator, and task invocation.
### `TaskEnvironment` level
You can configure caching at the `TaskEnvironment` level.
This will set the default cache behavior for all tasks defined using that environment.
For example:
```
cached_env = flyte.TaskEnvironment(
    name="cached_environment",
    cache=flyte.Cache(behavior="auto"),  # Default for all tasks
)


@cached_env.task  # Inherits auto caching from environment
async def inherits_caching(data: str) -> str:
    return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
### `@env.task` decorator level
By setting the cache parameter in the `@env.task` decorator, you can override the environment's default cache behavior for specific tasks:
```
@cached_env.task(cache=flyte.Cache(behavior="disable")) # Override environment default
async def decorator_caching(data: str) -> str:
    return await transform_data(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
### `task.override` level
By setting the cache parameter in the `task.override` method, you can override the cache behavior for specific task invocations:
```
@env.task
async def override_caching_on_call(data: str) -> str:
    # Create an overridden version and call it
    overridden_task = inherits_caching.override(cache=flyte.Cache(behavior="disable"))
    return await overridden_task(data)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/caching/caching.py)
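The three levels resolve from most to least specific. The sketch below models each level as a behavior string, or `None` when unset. It is a hypothetical simplification that ignores the other `Cache` fields such as `salt` or `ignored_inputs`.

```python
from typing import Optional


def effective_cache(
    env_cache: Optional[str] = None,
    decorator_cache: Optional[str] = None,
    override_cache: Optional[str] = None,
) -> str:
    """Pick the cache behavior for one invocation; the most specific level that is set wins."""
    for setting in (override_cache, decorator_cache, env_cache):
        if setting is not None:
            return setting
    return "disable"  # Caching is disabled by default
```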
## Runtime cache control
You can also force cache invalidation for a specific run:
```python
# Ignore existing cache entries for this run: re-execute and overwrite the cached results
run = flyte.with_runcontext(overwrite_cache=True).run(my_cached_task, data="test")
```
## Project and domain cache isolation
Caches are automatically isolated by:
- **Project**: Tasks in different projects have separate cache namespaces.
- **Domain**: Development, staging, and production domains maintain separate caches.
## Local development caching
When running locally, Flyte maintains a local cache:
```python
# Local execution uses ~/.flyte/local-cache/
flyte.init() # Local mode
result = flyte.run(my_cached_task, data="test")
```
Local cache behavior:
- Stored in `~/.flyte/local-cache/` directory
- No project/domain isolation (since running locally)
- Can be cleared with `flyte local-cache clear`
- Disabled by setting `FLYTE_LOCAL_CACHE_ENABLED=false`
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-configuration/reusable-containers ===
# Reusable containers
By default, each task execution in Flyte and Union runs in a fresh container instance that is created just for that execution and then discarded.
With reusable containers, the same container can be reused across multiple executions and tasks.
This approach reduces startup overhead and improves resource efficiency.
> [!NOTE]
> The reusable container feature is only available when running your Flyte code on a Union backend.
## How it works
With reusable containers, the system maintains a pool of persistent containers that can handle multiple task executions.
When you configure a `TaskEnvironment` with a `ReusePolicy`, the system does the following:
1. Creates a pool of persistent containers.
2. Routes task executions to available container instances.
3. Manages container lifecycle with configurable timeouts.
4. Supports concurrent task execution within containers (for async tasks).
5. Preserves the Python execution environment across task executions, allowing you to maintain state through global variables.
## Basic usage
> [!NOTE]
> The reusable containers feature currently requires a dedicated runtime library
> ([`unionai-reuse`](https://pypi.org/project/unionai-reuse/)) to be installed in the task image used by the reusable task.
> You can add this library to your task image using the `flyte.Image.with_pip_packages` method, as shown below.
> This library only needs to be added to the task image.
> It does not need to be installed in your local development environment.
Enable container reuse by adding a `ReusePolicy` to your `TaskEnvironment`:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# ]
# main = "main"
# params = "n=500"
# ///
import flyte
from datetime import timedelta

# Currently required to enable reusable containers
reusable_image = flyte.Image.from_debian_base().with_pip_packages("unionai-reuse>=0.1.3")

env = flyte.TaskEnvironment(
    name="reusable-env",
    resources=flyte.Resources(memory="1Gi", cpu="500m"),
    reusable=flyte.ReusePolicy(
        replicas=2,  # Create 2 container instances
        concurrency=1,  # Process 1 task per container at a time
        scaledown_ttl=timedelta(minutes=10),  # Individual containers shut down after 10 minutes of inactivity
        idle_ttl=timedelta(hours=1),  # Entire environment shuts down after 1 hour with no tasks
    ),
    image=reusable_image,  # Use the container image augmented with the unionai-reuse library.
)


@env.task
async def compute_task(x: int) -> int:
    return x * x


@env.task
async def main() -> list[int]:
    # These tasks will reuse containers from the pool
    results = []
    for i in range(10):
        result = await compute_task(i)
        results.append(result)
    return results


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/reusable-containers/reuse.py)
## `ReusePolicy` parameters
The `ReusePolicy` class controls how containers are managed in a reusable environment:
```python
flyte.ReusePolicy(
replicas: typing.Union[int, typing.Tuple[int, int]],
concurrency: int,
scaledown_ttl: typing.Union[int, datetime.timedelta],
idle_ttl: typing.Union[int, datetime.timedelta]
)
```
### `replicas`: Container pool size
Controls the number of container instances in the reusable pool:
- **Fixed size**: `replicas=3` creates exactly 3 container instances. These 3 replicas are shut down after `idle_ttl` expires.
- **Auto-scaling**: `replicas=(2, 5)` starts with 2 containers and can scale up to 5 based on demand.
  - If the task is running on 2 replicas and demand drops to zero, these 2 containers are shut down after `idle_ttl` expires.
  - If the task is running on 2 replicas and demand increases, new containers are created up to the maximum of 5.
  - If the task is running on 5 replicas and demand drops, container 5 is shut down after `scaledown_ttl` expires.
  - If demand drops again, container 4 is also shut down after another `scaledown_ttl` period.
- **Resource impact**: Each replica consumes the full resources defined in `TaskEnvironment.resources`.
```python
import flyte
from datetime import timedelta

# Fixed pool size
fixed_pool_policy = flyte.ReusePolicy(
    replicas=3,
    concurrency=1,
    scaledown_ttl=timedelta(minutes=10),
    idle_ttl=timedelta(hours=1)
)

# Auto-scaling pool
auto_scaling_policy = flyte.ReusePolicy(
    replicas=(1, 10),
    concurrency=1,
    scaledown_ttl=timedelta(minutes=10),
    idle_ttl=timedelta(hours=1)
)
```
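To make the scale-down timeline in the bullets above concrete, here is a small toy model in plain Python. It is not Union's autoscaler and calls no Flyte API: the `ToyPool` class and its `shutdown_schedule` method are hypothetical, and the model assumes (following the bullets) that replicas above the minimum are retired one per `scaledown_ttl` period after demand stops, and that whatever remains is torn down once `idle_ttl` elapses. The numbers mirror `replicas=(2, 5)` with a 10-minute `scaledown_ttl` and a 1-hour `idle_ttl`.

```python
from dataclasses import dataclass


@dataclass
class ToyPool:
    """Hypothetical model of an auto-scaling pool such as replicas=(2, 5)."""
    min_replicas: int
    max_replicas: int
    scaledown_ttl: float  # seconds an extra replica may sit idle before removal
    idle_ttl: float  # seconds with no tasks before the whole pool is removed

    def shutdown_schedule(self, running: int) -> list[tuple[float, int]]:
        """Return (seconds after demand stops, replicas still running) events."""
        events: list[tuple[float, int]] = []
        count = min(running, self.max_replicas)
        t = 0.0
        # Replicas above the minimum are retired one at a time, each after
        # another full scaledown_ttl period, unless idle_ttl is reached first.
        while count > self.min_replicas and t + self.scaledown_ttl < self.idle_ttl:
            t += self.scaledown_ttl
            count -= 1
            events.append((t, count))
        # Once idle_ttl has elapsed with no tasks, the entire pool is shut down.
        events.append((self.idle_ttl, 0))
        return events


pool = ToyPool(min_replicas=2, max_replicas=5, scaledown_ttl=600, idle_ttl=3600)
for seconds, remaining in pool.shutdown_schedule(running=5):
    print(f"t={seconds / 60:.0f} min: {remaining} replicas")
```

In this model, a pool running at 5 replicas drops to 4, 3, and then 2 replicas at 10-minute intervals, and the last 2 replicas disappear only when the 1-hour `idle_ttl` expires.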
### `concurrency`: Tasks per container
Controls how many tasks can execute simultaneously within a single container:
- **Default**: `concurrency=1` (one task per container at a time).
- **Higher concurrency**: `concurrency=5` allows 5 tasks to run simultaneously in each container.
- **Total capacity**: `replicas × concurrency` = maximum concurrent tasks across the entire pool.
```python
import flyte
from datetime import timedelta

# Sequential processing (default)
sequential_policy = flyte.ReusePolicy(
    replicas=2,
    concurrency=1,  # One task per container
    scaledown_ttl=timedelta(minutes=10),
    idle_ttl=timedelta(hours=1)
)

# Concurrent processing
concurrent_policy = flyte.ReusePolicy(
    replicas=2,
    concurrency=5,  # 5 tasks per container = 10 total concurrent tasks
    scaledown_ttl=timedelta(minutes=10),
    idle_ttl=timedelta(hours=1)
)
```
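The `replicas × concurrency` capacity rule can be illustrated with a toy model built on plain `asyncio`. Each hypothetical "container" is an `asyncio.Semaphore` with `concurrency` slots, and tasks are assigned to containers round-robin. This is a simplification for intuition only, not how Union actually places tasks onto replicas; `run_pool` is an illustrative name, not a Flyte API.

```python
import asyncio


async def run_pool(num_tasks: int, replicas: int, concurrency: int) -> int:
    """Toy model: return the peak number of tasks running at the same time."""
    # One semaphore per hypothetical container, each with `concurrency` slots.
    containers = [asyncio.Semaphore(concurrency) for _ in range(replicas)]
    running = 0
    peak = 0

    async def task(i: int) -> None:
        nonlocal running, peak
        async with containers[i % replicas]:  # round-robin container assignment
            running += 1
            peak = max(peak, running)
            await asyncio.sleep(0.01)  # stand-in for real work
            running -= 1

    await asyncio.gather(*(task(i) for i in range(num_tasks)))
    return peak


# 2 replicas × concurrency 5: at most 10 tasks run at once, even with 20 submitted.
print(asyncio.run(run_pool(num_tasks=20, replicas=2, concurrency=5)))  # 10
```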
### `idle_ttl` vs `scaledown_ttl`: Container lifecycle
These parameters work together to manage container lifecycle at different levels:
#### `idle_ttl`: Environment timeout
- **Scope**: Controls the entire reusable environment infrastructure.
- **Behavior**: When there are no active or queued tasks, the entire environment scales down after `idle_ttl` expires.
- **Purpose**: Manages the lifecycle of the entire container pool.
- **Typical values**: 1-2 hours, or `None` for always-on environments.
#### `scaledown_ttl`: Individual container timeout
- **Scope**: Controls individual container instances.
- **Behavior**: When a container finishes a task and becomes inactive, it will be terminated after `scaledown_ttl` expires.
- **Purpose**: Prevents resource waste from inactive containers.
- **Typical values**: 5-30 minutes for most workloads.
```python
import flyte
from datetime import timedelta

lifecycle_policy = flyte.ReusePolicy(
    replicas=3,
    concurrency=2,
    scaledown_ttl=timedelta(minutes=10),  # Individual containers shut down after 10 minutes of inactivity
    idle_ttl=timedelta(hours=1)  # Entire environment shuts down after 1 hour of no tasks
)
```
## Understanding parameter relationships
The four `ReusePolicy` parameters work together to control different aspects of container management:
```python
import flyte
from datetime import timedelta

reuse_policy = flyte.ReusePolicy(
    replicas=4,  # Infrastructure: How many containers?
    concurrency=3,  # Throughput: How many tasks per container?
    scaledown_ttl=timedelta(minutes=10),  # Individual: When do idle containers shut down?
    idle_ttl=timedelta(hours=1)  # Environment: When does the whole pool shut down?
)
# Total capacity: 4 × 3 = 12 concurrent tasks
# Individual containers shut down after 10 minutes of inactivity
# Entire environment shuts down after 1 hour of no tasks
```
### Key relationships
- **Total throughput** = `replicas × concurrency`
- **Resource usage** = `replicas × TaskEnvironment.resources`
- **Cost efficiency**: Higher `concurrency` reduces per-container overhead; more `replicas` provide better isolation.
- **Lifecycle management**: `scaledown_ttl` manages individual containers; `idle_ttl` manages the environment.
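The two formulas above lend themselves to a quick back-of-the-envelope check before you pick values. The sketch below is a hypothetical helper in plain Python (`PoolSizing` is not part of Flyte). It assumes, as stated above, that every replica consumes the full `TaskEnvironment.resources`, and for an auto-scaling `(min, max)` tuple it sizes for the upper bound as the worst case.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class PoolSizing:
    """Hypothetical helper for rough ReusePolicy sizing."""
    replicas: int | tuple[int, int]  # fixed size, or (min, max) for auto-scaling
    concurrency: int
    cpu_per_replica: float  # cores, e.g. 0.5 for cpu="500m"
    memory_gib_per_replica: float  # GiB, e.g. 1.0 for memory="1Gi"

    @property
    def max_replicas(self) -> int:
        # For an auto-scaling (min, max) tuple, size for the upper bound.
        return self.replicas[1] if isinstance(self.replicas, tuple) else self.replicas

    def max_concurrent_tasks(self) -> int:
        # Total throughput = replicas × concurrency
        return self.max_replicas * self.concurrency

    def peak_resources(self) -> tuple[float, float]:
        # Resource usage = replicas × TaskEnvironment.resources (CPU cores, GiB)
        return (
            self.max_replicas * self.cpu_per_replica,
            self.max_replicas * self.memory_gib_per_replica,
        )


sizing = PoolSizing(replicas=4, concurrency=3, cpu_per_replica=0.5, memory_gib_per_replica=1.0)
print(sizing.max_concurrent_tasks())  # 12
print(sizing.peak_resources())  # (2.0, 4.0)
```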
## Simple example
Here is a simple but complete example of reuse with concurrency.
First, import the needed modules and set up logging:
```
import asyncio
import logging
import flyte
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/reusable-containers/reuse_concurrency.py)
Next, we set up the reusable task environment. Note that, currently, the image used for a reusable environment requires an extra package to be installed:
```
env = flyte.TaskEnvironment(
name="reuse_concurrency",
resources=flyte.Resources(cpu=1, memory="1Gi"),
reusable=flyte.ReusePolicy(
replicas=2,
idle_ttl=60,
concurrency=100,
scaledown_ttl=60,
),
image=flyte.Image.from_debian_base().with_pip_packages("unionai-reuse==0.1.5b0", pre=True),
)
```
Now, we define the `main` task (the driver task of the workflow) and the `noop` task that will be executed multiple times, reusing the same containers:
```
@env.task
async def noop(x: int) -> int:
    logger.debug(f"Task noop: {x}")
    return x

@env.task
async def main(n: int = 50) -> int:
    coros = [noop(i) for i in range(n)]
    results = await asyncio.gather(*coros)
    return sum(results)
```
Finally, we deploy and run the workflow programmatically, so all you have to do is execute `python reuse_concurrency.py` to see it in action:
```
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, n=500)
    print(r.name)
    print(r.url)
    r.wait()
```
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-configuration/pod-templates ===
# Pod templates
The `pod_template` parameter in `TaskEnvironment` (and in the `@env.task` decorator, if you are overriding it) allows you to customize the Kubernetes pod specification that will be used to run your tasks.
This provides fine-grained control over the underlying Kubernetes resources, enabling you to configure advanced pod settings like image pull secrets, environment variables, labels, annotations, and other pod-level configurations.
## Overview
Pod templates in Flyte allow you to:
- **Configure pod metadata**: Set custom labels and annotations for your pods.
- **Specify image pull secrets**: Access private container registries.
- **Set environment variables**: Configure container-level environment variables.
- **Customize pod specifications**: Define advanced Kubernetes pod settings.
- **Control container configurations**: Specify primary container settings.
The `pod_template` parameter accepts either a string reference or a `PodTemplate` object that defines the complete pod specification.
## Basic usage
Here's a complete example showing how to use pod templates with a `TaskEnvironment`:
```
# /// script
# requires-python = "==3.12"
# dependencies = [
# "flyte==2.0.0b31",
# "kubernetes"
# ]
# ///
import flyte
from kubernetes.client import (
    V1Container,
    V1EnvVar,
    V1LocalObjectReference,
    V1PodSpec,
)

# Create a custom pod template
pod_template = flyte.PodTemplate(
    primary_container_name="primary",  # Name of the main container
    labels={"lKeyA": "lValA"},  # Custom pod labels
    annotations={"aKeyA": "aValA"},  # Custom pod annotations
    pod_spec=V1PodSpec(  # Kubernetes pod specification
        containers=[
            V1Container(
                name="primary",
                env=[V1EnvVar(name="hello", value="world")]  # Environment variables
            )
        ],
        image_pull_secrets=[  # Access to private registries
            V1LocalObjectReference(name="regcred-test")
        ],
    ),
)

# Use the pod template in a TaskEnvironment
env = flyte.TaskEnvironment(
    name="hello_world",
    pod_template=pod_template,  # Apply the custom pod template
    image=flyte.Image.from_uv_script(__file__, name="flyte", pre=True),
)

@env.task
async def say_hello(data: str) -> str:
    return f"Hello {data}"

@env.task
async def say_hello_nested(data: str = "default string") -> str:
    return await say_hello(data=data)

if __name__ == "__main__":
    flyte.init_from_config()
    result = flyte.run(say_hello_nested, data="hello world")
    print(result.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/pod-templates/pod_template.py)
## PodTemplate components
The `PodTemplate` class provides the following parameters for customizing your pod configuration:
```python
pod_template = flyte.PodTemplate(
primary_container_name: str = "primary",
pod_spec: Optional[V1PodSpec] = None,
labels: Optional[Dict[str, str]] = None,
annotations: Optional[Dict[str, str]] = None
)
```
### Parameters
- **`primary_container_name`** (`str`, default: `"primary"`): Specifies the name of the main container that will run your task code. This must match the container name defined in your pod specification.
- **`pod_spec`** (`Optional[V1PodSpec]`): A standard Kubernetes `V1PodSpec` object that defines the complete pod specification. This allows you to configure any pod-level setting including containers, volumes, security contexts, node selection, and more.
- **`labels`** (`Optional[Dict[str, str]]`): Key-value pairs used for organizing and selecting pods. Labels are used by Kubernetes selectors and can be queried to filter and manage pods.
- **`annotations`** (`Optional[Dict[str, str]]`): Additional metadata attached to the pod that doesn't affect pod scheduling or selection. Annotations are typically used for storing non-identifying information like deployment revisions, contact information, or configuration details.
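Because `primary_container_name` must match a container in `pod_spec`, a quick pre-flight check can catch a mismatch (for example `"primary"` vs. `"main"`) before you deploy. The sketch below works on a plain-dictionary mirror of the pod spec; the keys `containers`, `name`, `env`, and `image_pull_secrets` follow the Kubernetes PodSpec schema, while the function `find_primary_container` is hypothetical and not part of Flyte.

```python
def find_primary_container(pod_spec: dict, primary_container_name: str = "primary") -> dict:
    """Return the container entry whose name matches primary_container_name.

    Raises ValueError listing the container names that were found,
    which makes a naming mismatch easy to spot.
    """
    containers = pod_spec.get("containers", [])
    for container in containers:
        if container.get("name") == primary_container_name:
            return container
    found = [c.get("name") for c in containers]
    raise ValueError(
        f"primary_container_name={primary_container_name!r} not found in pod_spec; "
        f"containers: {found}"
    )


# Dictionary mirror of the V1PodSpec used in the basic usage example
pod_spec = {
    "containers": [{"name": "primary", "env": [{"name": "hello", "value": "world"}]}],
    "image_pull_secrets": [{"name": "regcred-test"}],
}
primary = find_primary_container(pod_spec)
print(primary["env"])
```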
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-configuration/multiple-environments ===
# Multiple environments
In many applications, different tasks within your workflow may require different configurations.
Flyte enables you to manage this complexity by allowing multiple environments within a single workflow.
Multiple environments are useful when:
- Different tasks in your workflow need different dependencies.
- Some tasks require specific CPU/GPU or memory configurations.
- A task requires a secret that other tasks do not (and you want to limit exposure of the secret value).
- You're integrating specialized tools that have conflicting requirements.
## Constraints on multiple environments
To use multiple environments in your workflow you define multiple `TaskEnvironment` instances, each with its own configuration, and then assign tasks to their respective environments.
There are, however, two additional constraints that you must take into account.
If `task_1` in environment `env_1` calls `task_2` in environment `env_2`, then:
1. `env_1` must declare a deployment-time dependency on `env_2` in the `depends_on` parameter of the `TaskEnvironment` that defines `env_1`.
2. The image used in the `TaskEnvironment` of `env_1` must include all dependencies of the module containing `task_2` (unless `task_2` is invoked as a remote task).
### Task `depends_on` constraints
The `depends_on` parameter in `TaskEnvironment` is used to provide deployment-time dependencies by establishing a relationship between one `TaskEnvironment` and another.
The system uses this information to determine which environments (and, specifically, which images) need to be built in order to run the code.
On `flyte run` (or `flyte deploy`), the system walks the tree defined by the `depends_on` relationships, starting with the environment of the task being invoked (or the environment being deployed, in the case of `flyte deploy`), and prepares each required environment.
Most importantly, it ensures that the container images needed for all required environments are available (and, if not, it builds them).
This deploy-time determination of what to build is important because it means that for any given `run` or `deploy`, only those environments that are actually required are built.
The alternative strategy of building all environments defined in the set of deployed code can lead to unnecessary and expensive builds, especially when iterating on code.
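The traversal described above can be sketched as a depth-first walk over `depends_on` edges. This is a toy model, not Flyte's implementation: the `Env` dataclass and `environments_to_prepare` function are hypothetical stand-ins for `TaskEnvironment` and the deploy-time logic, and the environment names borrow from the AlphaFold2 example on this page. It shows why an environment that is not reachable from the invoked task's environment is never built.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Env:
    """Hypothetical stand-in for a TaskEnvironment: name, image, depends_on."""
    name: str
    image: str
    depends_on: list[Env] = field(default_factory=list)


def environments_to_prepare(root: Env) -> list[str]:
    """Return names of environments reachable from root via depends_on, root first."""
    seen: set[str] = set()
    order: list[str] = []
    stack = [root]
    while stack:
        env = stack.pop()
        if env.name in seen:  # guards against shared or cyclic dependencies
            continue
        seen.add(env.name)
        order.append(env.name)
        # Push in reverse so dependencies are visited in declaration order.
        stack.extend(reversed(env.depends_on))
    return order


fold_env = Env("fold_env", "fold_image")
msa_env = Env("msa_env", "msa_image")
multi_env = Env("multi_env", "main_image", depends_on=[fold_env, msa_env])
unrelated_env = Env("unrelated_env", "unrelated_image")  # never reached, so never built

print(environments_to_prepare(multi_env))  # ['multi_env', 'fold_env', 'msa_env']
print(environments_to_prepare(fold_env))  # ['fold_env']
```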
### Dependency inclusion constraints
When a parent task invokes a child task in a different environment, the container image of the parent task environment must include all dependencies used by the child task.
This is necessary because of the way task invocation works in Flyte:
- When a child task is invoked by function name, that function necessarily has to be imported into the parent task's Python environment.
- This results in all the dependencies of the child task function also being imported.
- Nonetheless, the actual execution of the child task occurs in its own environment.
To avoid this requirement, you can invoke a task in another environment _remotely_.
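The image requirement boils down to simple set arithmetic. In this hypothetical sketch, each child task is described by the packages its module needs; every child invoked locally (by function name) must have those packages in the parent image, while children invoked remotely are skipped, per the rule above. The function `missing_packages` and its arguments are illustrative only, not a Flyte API; the package lists mirror the AlphaFold2 example on this page.

```python
def missing_packages(
    parent_image_packages: set[str],
    child_task_packages: dict[str, set[str]],
    remote_tasks: frozenset[str] = frozenset(),
) -> dict[str, set[str]]:
    """Map each locally invoked child task to the packages the parent image lacks."""
    missing: dict[str, set[str]] = {}
    for task_name, packages in child_task_packages.items():
        if task_name in remote_tasks:
            continue  # remote invocation: the parent image need not carry these packages
        gap = packages - parent_image_packages
        if gap:
            missing[task_name] = gap
    return missing


children = {"run_msa": {"pytest"}, "run_fold": {"ruff"}}
print(missing_packages({"ruff"}, children))  # {'run_msa': {'pytest'}}
print(missing_packages({"ruff", "pytest"}, children))  # {}
print(missing_packages({"ruff"}, children, remote_tasks=frozenset({"run_msa"})))  # {}
```

Starting from the fold image and adding the MSA packages, as `main.py` does below, is exactly what drives the middle case to an empty result.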
## Example
The following example is a (very) simple mock of an AlphaFold2 pipeline.
It demonstrates a workflow with three tasks, each in its own environment.
The example project looks like this:
```bash
├── msa/
│   ├── __init__.py
│   └── run.py
├── fold/
│   ├── __init__.py
│   └── run.py
├── __init__.py
└── main.py
```
(The source code for this example can be found here: [AlphaFold2 mock example](https://github.com/unionai/unionai-examples/tree/main/v2/user-guide/task-configuration/multiple-environments/af2))
In file `msa/run.py` we define the task `run_msa`, which mocks the multiple sequence alignment step of the process:
```python
import flyte
from flyte.io import File

MSA_PACKAGES = ["pytest"]
msa_image = flyte.Image.from_debian_base().with_pip_packages(*MSA_PACKAGES)
msa_env = flyte.TaskEnvironment(name="msa_env", image=msa_image)

@msa_env.task
def run_msa(x: str) -> File:
    f = File.new_remote()
    with f.open_sync("w") as fp:
        fp.write(x)
    return f
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/msa/run.py)
* A dedicated image (`msa_image`) is built using the `MSA_PACKAGES` dependency list, on top of the standard base image.
* A dedicated environment (`msa_env`) is defined for the task, using `msa_image`.
* The task is defined within the context of the `msa_env` environment.
In file `fold/run.py` we define the task `run_fold`, which mocks the fold step of the process:
```python
import flyte
from flyte.io import File

FOLD_PACKAGES = ["ruff"]
fold_image = flyte.Image.from_debian_base().with_pip_packages(*FOLD_PACKAGES)
fold_env = flyte.TaskEnvironment(name="fold_env", image=fold_image)

@fold_env.task
def run_fold(sequence: str, msa: File) -> list[str]:
    with msa.open_sync("r") as f:
        msa_content = f.read()
    return [msa_content, sequence]
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/fold/run.py)
* A dedicated image (`fold_image`) is built using the `FOLD_PACKAGES` dependency list, on top of the standard base image.
* A dedicated environment (`fold_env`) is defined for the task, using `fold_image`.
* The task is defined within the context of the `fold_env` environment.
Finally, in file `main.py` we define the task `main` that ties everything together into a workflow.
We import the required modules and functions:
```
import logging
import pathlib
from fold.run import fold_env, fold_image, run_fold
from msa.run import msa_env, MSA_PACKAGES, run_msa
import flyte
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/multiple-environments/af2/main.py)
Notice that we import:
* The task functions that we will be calling: `run_fold` and `run_msa`.
* The environments of those tasks: `fold_env` and `msa_env`.
* The dependency list of the `run_msa` task: `MSA_PACKAGES`.
* The image of the `run_fold` task: `fold_image`.
We then assemble the image and the environment:
```
main_image = fold_image.with_pip_packages(*MSA_PACKAGES)
env = flyte.TaskEnvironment(
name="multi_env",
depends_on=[fold_env, msa_env],
image=main_image,
)
```
The image for the `main` task (`main_image`) is built by starting with `fold_image` (the image for the `run_fold` task) and adding `MSA_PACKAGES` (the dependency list for the `run_msa` task).
This ensures that `main_image` includes all dependencies needed by both the `run_fold` and `run_msa` tasks.
The environment for the `main` task is defined with:
* The image `main_image`. This ensures that the `main` task has all the dependencies it needs.
* A `depends_on` list that includes both `fold_env` and `msa_env`. This establishes the deploy-time dependencies on those environments.
Finally, we define the `main` task itself:
```
@env.task
def main(sequence: str) -> list[str]:
    """Given a sequence, outputs files containing the protein structure
    This requires model weights + gpus + large database on aws fsx lustre
    """
    print(f"Running AlphaFold2 for sequence: {sequence}")
    msa = run_msa(sequence)
    print(f"MSA result: {msa}, passing to fold task")
    results = run_fold(sequence, msa)
    print(f"Fold results: {results}")
    return results
```
Here we call, in turn, the `run_msa` and `run_fold` tasks.
Since we call them directly rather than as remote tasks, we had to ensure that `main_image` includes all dependencies needed by both tasks.
The final piece of the puzzle is the `if __name__ == "__main__":` block that allows us to run the `main` task on the configured Flyte backend:
```
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main, "AAGGTTCCAA")
    print(r.name)
    print(r.url)
    r.wait()
```
Now you can run the workflow with:
```bash
python main.py
```
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-configuration/retries-and-timeouts ===
# Retries and timeouts
Flyte provides robust error handling through configurable retry strategies and timeout controls.
These parameters help ensure task reliability and prevent resource waste from runaway processes.
## Retries
The `retries` parameter controls how many times a failed task should be retried before giving up.
A "retry" is any attempt after the initial attempt.
In other words, `retries=3` means the task may be attempted up to 4 times in total (1 initial + 3 retries).
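The attempt-counting rule can be sketched with a plain-Python retry loop. This only models the arithmetic (1 initial attempt plus `retries` retries); it is not how Flyte implements retries, and `run_with_retries` and `make_flaky` are hypothetical helpers.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def run_with_retries(fn: Callable[[], T], retries: int) -> tuple[T, int]:
    """Call fn up to retries + 1 times; return (result, number of attempts used)."""
    last_error: Optional[Exception] = None
    for attempt in range(1, retries + 2):  # 1 initial attempt + `retries` retries
        try:
            return fn(), attempt
        except Exception as exc:  # any failure consumes one attempt
            last_error = exc
    raise RuntimeError(f"gave up after {retries + 1} attempts") from last_error


def make_flaky(failures: int) -> Callable[[], str]:
    """Return a hypothetical task body that fails `failures` times, then succeeds."""
    calls = 0

    def flaky() -> str:
        nonlocal calls
        calls += 1
        if calls <= failures:
            raise Exception("Task failed!")
        return "Success!"

    return flaky


print(run_with_retries(make_flaky(3), retries=3))  # ('Success!', 4)
```

With `retries=3`, a body that fails three times still succeeds on the fourth and final attempt; with `retries=2`, the same body exhausts its three attempts.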
The `retries` parameter can be configured either in the `@env.task` decorator or by using `override` when invoking the task.
It cannot be configured in the `TaskEnvironment` definition.
The code for the examples below can be found on [GitHub](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/retries.py).
### Retry example
First we import the required modules and set up a task environment:
```
import random
from datetime import timedelta
import flyte
env = flyte.TaskEnvironment(name="my-env")
```
Then we configure our task to retry up to 3 times if it fails (for a total of 4 attempts). We also define the driver task `main` that calls the `retry` task:
```
@env.task(retries=3)
async def retry() -> str:
    if random.random() < 0.7:  # 70% failure rate
        raise Exception("Task failed!")
    return "Success!"

@env.task
async def main() -> list[str]:
    results = []
    try:
        results.append(await retry())
    except Exception as e:
        results.append(f"Failed: {e}")
    try:
        results.append(await retry.override(retries=5)())
    except Exception as e:
        results.append(f"Failed: {e}")
    return results
```
Note that we call `retry` twice: first without any `override`, and then with an `override` to increase the retries to 5 (for a total of 6 attempts).
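Since `retry` fails 70% of the time in this example, the chance that every attempt fails shrinks geometrically as attempts are added. The quick calculation below assumes attempts fail independently (as they do with `random.random()` here); `probability_all_attempts_fail` is an illustrative helper, not a Flyte API.

```python
def probability_all_attempts_fail(failure_rate: float, retries: int) -> float:
    """Chance that the initial attempt and every retry fail, assuming independence."""
    return failure_rate ** (retries + 1)


print(f"retries=3: {probability_all_attempts_fail(0.7, 3):.4f}")  # 0.2401
print(f"retries=5: {probability_all_attempts_fail(0.7, 5):.4f}")  # 0.1176
```

So the first call still ends up as `Failed: ...` about 24% of the time, while the overridden call fails outright only about 12% of the time.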
Finally, we configure Flyte and invoke the `main` task:
```
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
## Timeouts
The `timeout` parameter sets limits on how long a task can run, preventing resource waste from stuck processes.
It supports multiple formats for different use cases.
The `timeout` parameter can be configured either in the `@env.task` decorator or by using `override` when invoking the task.
It cannot be configured in the `TaskEnvironment` definition.
The code for the example below can be found on [GitHub](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/retries-and-timeouts/timeouts.py).
### Timeout example
First, we import the required modules and set up a task environment:
```
import random
from datetime import timedelta
import asyncio
import flyte
from flyte import Timeout
env = flyte.TaskEnvironment(name="my-env")
```
Our first task sets a timeout using seconds as an integer:
```
@env.task(timeout=60) # 60 seconds
async def timeout_seconds() -> str:
    await asyncio.sleep(random.randint(0, 120))  # Random wait between 0 and 120 seconds
    return "timeout_seconds completed"
```
We can also set a timeout using a `timedelta` object for more readable durations:
```
@env.task(timeout=timedelta(minutes=1))
async def timeout_timedelta() -> str:
    await asyncio.sleep(random.randint(0, 120))  # Random wait between 0 and 120 seconds
    return "timeout_timedelta completed"
```
You can also set separate timeouts for maximum execution time and maximum queue time using the `Timeout` class:
```
@env.task(timeout=Timeout(
    max_runtime=timedelta(minutes=1),  # Max execution time per attempt
    max_queued_time=timedelta(minutes=1)  # Max time in queue before starting
))
async def timeout_advanced() -> str:
    await asyncio.sleep(random.randint(0, 120))  # Random wait between 0 and 120 seconds
    return "timeout_advanced completed"
```
You can also combine retries and timeouts for resilience and resource control:
```
@env.task(
    retries=3,
    timeout=Timeout(
        max_runtime=timedelta(minutes=1),
        max_queued_time=timedelta(minutes=1)
    )
)
async def timeout_with_retry() -> str:
    await asyncio.sleep(random.randint(0, 120))  # Random wait between 0 and 120 seconds
    return "timeout_with_retry completed"
```
Here we specify:
- Up to 3 retries (4 attempts in total).
- Each attempt times out after 1 minute.
- The task fails if it is queued for more than 1 minute.
- Total possible runtime: 1 minute queue + (1 minute × 4 attempts).
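The upper bound in the list can be written as a small helper. This sketch assumes, as the list does, that the queue limit is counted once while the runtime limit applies to every attempt; whether Flyte also applies the queue limit again before each retry is not covered here. `worst_case_duration` is a hypothetical name, and it accepts either integer seconds or a `timedelta`, mirroring the formats that `timeout` accepts.

```python
from datetime import timedelta
from typing import Union

Duration = Union[int, timedelta]


def _as_timedelta(value: Duration) -> timedelta:
    # Integers are interpreted as seconds, like timeout=60 above.
    return value if isinstance(value, timedelta) else timedelta(seconds=value)


def worst_case_duration(max_runtime: Duration, max_queued_time: Duration, retries: int) -> timedelta:
    """Upper bound: one queue wait plus max_runtime for each of the retries + 1 attempts."""
    attempts = retries + 1
    return _as_timedelta(max_queued_time) + attempts * _as_timedelta(max_runtime)


print(worst_case_duration(timedelta(minutes=1), timedelta(minutes=1), retries=3))  # 0:05:00
```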
We define the `main` driver task that calls all the timeout tasks concurrently and returns their outputs as a list. The return value for failed tasks will indicate failure:
```
@env.task
async def main() -> list[str]:
    tasks = [
        timeout_seconds(),
        timeout_seconds.override(timeout=120)(),  # Override to 120 seconds
        timeout_timedelta(),
        timeout_advanced(),
        timeout_with_retry(),
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    output = []
    for r in results:
        if isinstance(r, Exception):
            output.append(f"Failed: {r}")
        else:
            output.append(r)
    return output
```
Note that we also demonstrate overriding the timeout for `timeout_seconds` to 120 seconds when calling it.
Finally, we configure Flyte and invoke the `main` task:
```
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
Proper retry and timeout configuration ensures your Flyte workflows are both reliable and efficient, handling transient failures gracefully while preventing resource waste.
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-configuration/triggers ===
# Triggers
Triggers allow you to automate and parameterize an execution by scheduling its start time and providing overrides for its task inputs.
Currently, only **schedule triggers** are supported.
This type of trigger runs a task based on a Cron expression or a fixed-rate schedule.
Support is coming for other trigger types, such as:
* Webhook triggers: Hit an API endpoint to run your task.
* Artifact triggers: Run a task when a specific artifact is produced.
## Triggers are set in the task decorator
A trigger is created by setting the `triggers` parameter in the task decorator to a `flyte.Trigger` object or a list of such objects (triggers are not settable at the `TaskEnvironment` definition or `task.override` levels).
Here is a simple example:
```
import flyte
from datetime import datetime, timezone
env = flyte.TaskEnvironment(name="trigger_env")
@env.task(triggers=flyte.Trigger.hourly()) # Every hour
def hourly_task(trigger_time: datetime, x: int = 1) -> str:
    return f"Hourly example executed at {trigger_time.isoformat()} with x={x}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
Here we use a predefined schedule trigger to run the `hourly_task` every hour.
Other predefined triggers can be used similarly (see **Configure tasks > Triggers > Predefined schedule triggers** below).
If you want full control over the trigger behavior, you can define a trigger using the `flyte.Trigger` class directly.
## `flyte.Trigger`
The `Trigger` class allows you to define custom triggers with full control over scheduling and execution behavior. It has the following signature:
```
flyte.Trigger(
name,
automation,
description="",
auto_activate=True,
inputs=None,
env_vars=None,
interruptible=None,
overwrite_cache=False,
queue=None,
labels=None,
annotations=None
)
```
### Core Parameters
**`name: str`** (required)
The unique identifier for the trigger within your project/domain.
**`automation: Union[Cron, FixedRate]`** (required)
Defines when the trigger fires. Use `flyte.Cron("expression")` for Cron-based scheduling or `flyte.FixedRate(interval_minutes, start_time=start_time)` for fixed intervals.
### Configuration Parameters
**`description: str = ""`**
Human-readable description of the trigger's purpose.
**`auto_activate: bool = True`**
Whether the trigger should be automatically activated when deployed. Set to `False` to deploy inactive triggers that require manual activation.
**`inputs: Dict[str, Any] | None = None`**
Default parameter values for the task when triggered. Use `flyte.TriggerTime` as a value to inject the trigger execution timestamp into that parameter.
### Runtime Override Parameters
**`env_vars: Dict[str, str] | None = None`**
Environment variables to set for triggered executions, overriding the task's default environment variables.
**`interruptible: bool | None = None`**
Whether triggered executions can be interrupted (useful for cost optimization with spot/preemptible instances). Overrides the task's interruptible setting.
**`overwrite_cache: bool = False`**
Whether to bypass/overwrite task cache for triggered executions, ensuring fresh computation.
**`queue: str | None = None`**
Specific execution queue for triggered runs, overriding the task's default queue.
### Metadata Parameters
**`labels: Mapping[str, str] | None = None`**
Key-value labels for organizing and filtering triggers (e.g., team, component, priority).
**`annotations: Mapping[str, str] | None = None`**
Additional metadata, often used by infrastructure tools for compliance, monitoring, or cost tracking.
Here's a comprehensive example showing all parameters:
```
comprehensive_trigger = flyte.Trigger(
name="monthly_financial_report",
automation=flyte.Cron("0 6 1 * *", timezone="America/New_York"),
description="Monthly financial report generation for executive team",
auto_activate=True,
inputs={
"report_date": flyte.TriggerTime,
"report_type": "executive_summary",
"include_forecasts": True
},
env_vars={
"REPORT_OUTPUT_FORMAT": "PDF",
"EMAIL_NOTIFICATIONS": "true"
},
interruptible=False, # Critical report, use dedicated resources
overwrite_cache=True, # Always fresh data
queue="financial-reports",
labels={
"team": "finance",
"criticality": "high",
"automation": "scheduled"
},
annotations={
"compliance.company.com/sox-required": "true",
"backup.company.com/retain-days": "2555" # 7 years
}
)
```
## The `automation` parameter with `flyte.FixedRate`
You can define a fixed-rate schedule trigger by setting the `automation` parameter of the `flyte.Trigger` to an instance of `flyte.FixedRate`.
The `flyte.FixedRate` has the following signature:
```
flyte.FixedRate(
interval_minutes,
start_time=None
)
```
### Parameters
**`interval_minutes: int`** (required)
The interval between trigger executions in minutes.
**`start_time: datetime | None`**
When to start the fixed rate schedule. If not specified, starts when the trigger is deployed and activated.
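To reason about when a fixed-rate trigger will fire, it can help to list the upcoming times. This plain-Python sketch assumes fires land on the grid `start_time + k × interval` (k = 0, 1, 2, ...), which is an assumption about the schedule rather than documented Flyte behavior; `fixed_rate_fire_times` and its `not_before` argument are hypothetical.

```python
from datetime import datetime, timedelta
from typing import Optional


def fixed_rate_fire_times(
    start_time: datetime,
    interval_minutes: int,
    count: int,
    not_before: Optional[datetime] = None,
) -> list[datetime]:
    """Return the next `count` times on the grid start_time + k * interval (k >= 0)."""
    step = timedelta(minutes=interval_minutes)
    first = start_time
    if not_before is not None and not_before > start_time:
        # Ceiling division on timedeltas: whole intervals needed to reach not_before.
        k = -(-(not_before - start_time) // step)
        first = start_time + k * step
    return [first + i * step for i in range(count)]


start = datetime(2025, 12, 1, 9, 0, 0)  # same start as the batch_job example below
for t in fixed_rate_fire_times(start, interval_minutes=360, count=3):
    print(t)
```

Under this assumption, the 6-hour `batch_job` schedule would fire at 09:00, 15:00, and 21:00 on its first day.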
### Examples
```
# Every 90 minutes, starting when deployed
every_90_min = flyte.Trigger(
"data_processing",
flyte.FixedRate(interval_minutes=90)
)
# Every 6 hours (360 minutes), starting at a specific time
specific_start = flyte.Trigger(
"batch_job",
flyte.FixedRate(
interval_minutes=360, # 6 hours
start_time=datetime(2025, 12, 1, 9, 0, 0) # Start Dec 1st at 9 AM
)
)
```
## The `automation` parameter with `flyte.Cron`
You can define a Cron-based schedule trigger by setting the `automation` parameter to an instance of `flyte.Cron`.
The `flyte.Cron` has the following signature:
```
flyte.Cron(
cron_expression,
timezone=None
)
```
### Parameters
**`cron_expression: str`** (required)
The cron expression defining when the trigger should fire. Uses standard Unix cron format with five fields: minute, hour, day of month, month, and day of week.
**`timezone: str | None`**
The timezone for the cron expression. If not specified, it defaults to UTC. Uses standard timezone names like "America/New_York" or "Europe/London".
### Examples
```
# Every day at 6 AM UTC
daily_trigger = flyte.Trigger(
"daily_report",
flyte.Cron("0 6 * * *")
)
# Every weekday at 9:30 AM Eastern Time
weekday_trigger = flyte.Trigger(
"business_hours_task",
flyte.Cron("30 9 * * 1-5", timezone="America/New_York")
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
#### Cron Expressions
Here are some common cron expressions you can use:
| Expression | Description |
|----------------|--------------------------------------|
| `0 0 * * *` | Every day at midnight |
| `0 9 * * 1-5` | Every weekday at 9 AM |
| `30 14 * * 6` | Every Saturday at 2:30 PM |
| `0 0 1 * *` | First day of every month at midnight |
| `0 0 25 * *` | 25th day of every month at midnight |
| `0 0 * * 0` | Every Sunday at midnight |
| `*/10 * * * *` | Every 10 minutes |
| `0 */2 * * *` | Every 2 hours |
For a full guide on Cron syntax, refer to [Crontab Guru](https://crontab.guru/).
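The five-field format can be checked with a small matcher. The one below is a minimal sketch and is not part of the Flyte API. The names `cron_matches` and `field_matches` are hypothetical. It supports only `*`, `*/n`, single values, ranges `a-b`, and comma lists. It also simplifies standard cron in one way: all five fields must match, whereas standard cron ORs day-of-month and day-of-week when both are restricted.

```python
from datetime import datetime

# Lowest legal value of each field: minute, hour, day of month, month, day of week.
FIELD_MINIMUMS = [0, 0, 1, 1, 0]


def field_matches(spec: str, value: int, low: int) -> bool:
    """Return True if `value` is allowed by a single cron field `spec`."""
    for part in spec.split(","):
        if part == "*":
            return True
        if part.startswith("*/"):
            # Steps count from the field's lowest value (e.g. days 1, 11, 21, 31).
            if (value - low) % int(part[2:]) == 0:
                return True
        elif "-" in part:
            start, end = (int(x) for x in part.split("-"))
            if start <= value <= end:
                return True
        elif int(part) == value:
            return True
    return False


def cron_matches(expr: str, dt: datetime) -> bool:
    """Check a five-field cron expression against a datetime (seconds are ignored)."""
    fields = expr.split()
    if len(fields) != 5:
        raise ValueError("expected: minute hour day-of-month month day-of-week")
    # Cron numbers weekdays from Sunday=0; Python's weekday() uses Monday=0.
    cron_dow = (dt.weekday() + 1) % 7
    values = [dt.minute, dt.hour, dt.day, dt.month, cron_dow]
    return all(
        field_matches(spec, value, low)
        for spec, value, low in zip(fields, values, FIELD_MINIMUMS)
    )


# 2025-10-25 is a Saturday.
print(cron_matches("30 14 * * 6", datetime(2025, 10, 25, 14, 30)))  # True
print(cron_matches("0 9 * * 1-5", datetime(2025, 10, 25, 9, 0)))    # False
```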
## The `inputs` parameter
The `inputs` parameter allows you to provide default values for your task's parameters when the trigger fires.
This is essential for parameterizing your automated executions and passing trigger-specific data to your tasks.
### Basic Usage
```
trigger_with_inputs = flyte.Trigger(
"data_processing",
flyte.Cron("0 6 * * *"), # Daily at 6 AM
inputs={
"batch_size": 1000,
"environment": "production",
"debug_mode": False
}
)
@env.task(triggers=trigger_with_inputs)
def process_data(batch_size: int, environment: str, debug_mode: bool = True) -> str:
    return f"Processing {batch_size} items in {environment} mode"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
### Using `flyte.TriggerTime`
The special `flyte.TriggerTime` value is used in the `inputs` to indicate the task parameter into which Flyte will inject the trigger execution timestamp:
```
timestamp_trigger = flyte.Trigger(
"daily_report",
flyte.Cron("0 0 * * *"), # Daily at midnight
inputs={
"report_date": flyte.TriggerTime, # Receives trigger execution time
"report_type": "daily_summary"
}
)
@env.task(triggers=timestamp_trigger)
def generate_report(report_date: datetime, report_type: str) -> str:
    return f"Generated {report_type} for {report_date.strftime('%Y-%m-%d')}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
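You can think of `flyte.TriggerTime` as a placeholder that gets swapped for the actual firing time when a run is created. The sketch below models that substitution. It uses a hypothetical sentinel `TRIGGER_TIME` and a hypothetical helper `materialize_inputs`. It only illustrates the documented behavior and is not Flyte's implementation.

```python
from datetime import datetime, timezone
from typing import Any

# Hypothetical stand-in for flyte.TriggerTime: a unique marker object.
TRIGGER_TIME = object()


def materialize_inputs(inputs: dict[str, Any], fired_at: datetime) -> dict[str, Any]:
    """Return a copy of `inputs` with every placeholder replaced by the firing time."""
    return {
        name: fired_at if value is TRIGGER_TIME else value
        for name, value in inputs.items()
    }


trigger_inputs = {"report_date": TRIGGER_TIME, "report_type": "daily_summary"}
fired_at = datetime(2025, 10, 27, 0, 0, tzinfo=timezone.utc)
run_inputs = materialize_inputs(trigger_inputs, fired_at)
print(run_inputs["report_date"].strftime("%Y-%m-%d"))  # 2025-10-27
```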
### Required vs optional parameters
> [!IMPORTANT]
> If your task has parameters without default values, you **must** provide values for them in the trigger inputs, otherwise the trigger will fail to execute.
```python
# ❌ This will fail - missing required parameter 'data_source'
bad_trigger = flyte.Trigger(
"bad_trigger",
flyte.Cron("0 0 * * *")
# Missing inputs for required parameter 'data_source'
)
@env.task(triggers=bad_trigger)
def bad_trigger_task(data_source: str, batch_size: int = 100) -> str:
    return f"Processing from {data_source} with batch size {batch_size}"
# ✅ This works - all required parameters provided
good_trigger = flyte.Trigger(
"good_trigger",
flyte.Cron("0 0 * * *"),
inputs={
"data_source": "prod_database", # Required parameter
"batch_size": 500 # Override default
}
)
@env.task(triggers=good_trigger)
def good_trigger_task(data_source: str, batch_size: int = 100) -> str:
    return f"Processing from {data_source} with batch size {batch_size}"
```
### Complex input types
You can pass various data types through trigger inputs:
```
complex_trigger = flyte.Trigger(
"ml_training",
flyte.Cron("0 2 * * 1"), # Weekly on Monday at 2 AM
inputs={
"model_config": {
"learning_rate": 0.01,
"batch_size": 32,
"epochs": 100
},
"feature_columns": ["age", "income", "location"],
"validation_split": 0.2,
"training_date": flyte.TriggerTime
}
)
@env.task(triggers=complex_trigger)
def train_model(
model_config: dict,
feature_columns: list[str],
validation_split: float,
training_date: datetime
) -> str:
    return f"Training model with {len(feature_columns)} features on {training_date}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
## Predefined schedule triggers
For common scheduling needs, Flyte provides predefined trigger methods that create Cron-based schedules without requiring you to specify cron expressions manually.
These are convenient shortcuts for frequently used scheduling patterns.
### Available Predefined Triggers
```
minutely_trigger = flyte.Trigger.minutely() # Every minute
hourly_trigger = flyte.Trigger.hourly() # Every hour
daily_trigger = flyte.Trigger.daily() # Every day at midnight
weekly_trigger = flyte.Trigger.weekly() # Every week (Sundays at midnight)
monthly_trigger = flyte.Trigger.monthly() # Every month (1st day at midnight)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
For reference, here's what each predefined trigger is equivalent to:
```python
# These are functionally identical:
flyte.Trigger.minutely() == flyte.Trigger("minutely", flyte.Cron("* * * * *"))
flyte.Trigger.hourly() == flyte.Trigger("hourly", flyte.Cron("0 * * * *"))
flyte.Trigger.daily() == flyte.Trigger("daily", flyte.Cron("0 0 * * *"))
flyte.Trigger.weekly() == flyte.Trigger("weekly", flyte.Cron("0 0 * * 0"))
flyte.Trigger.monthly() == flyte.Trigger("monthly", flyte.Cron("0 0 1 * *"))
```
### Predefined Trigger Parameters
All predefined trigger methods (`minutely()`, `hourly()`, `daily()`, `weekly()`, `monthly()`) accept the same set of parameters:
```
flyte.Trigger.daily(
trigger_time_input_key="trigger_time",
name="daily",
description="A trigger that runs daily at midnight",
auto_activate=True,
inputs=None,
env_vars=None,
interruptible=None,
overwrite_cache=False,
queue=None,
labels=None,
annotations=None
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
#### Core Parameters
**`trigger_time_input_key: str = "trigger_time"`**
The name of the task parameter that will receive the execution timestamp.
If no `trigger_time_input_key` is provided, the default is `trigger_time`.
In this case, if the task does not have a parameter named `trigger_time`, the task is still executed, but the timestamp is not passed.
However, if you do specify a `trigger_time_input_key`, but your task does not actually have the specified parameter, an error will be raised at trigger deployment time.
**`name: str`**
The unique identifier for the trigger. Defaults to the method name (`"daily"`, `"hourly"`, etc.).
**`description: str`**
Human-readable description of the trigger's purpose. Each method has a sensible default.
#### Configuration Parameters
**`auto_activate: bool = True`**
Whether the trigger should be automatically activated when deployed. Set to `False` to deploy inactive triggers that require manual activation.
**`inputs: Dict[str, Any] | None = None`**
Additional parameter values for your task when triggered. The `trigger_time_input_key` parameter is automatically included with `flyte.TriggerTime` as its value.
#### Runtime Override Parameters
**`env_vars: Dict[str, str] | None = None`**
Environment variables to set for triggered executions, overriding the task's default environment variables.
**`interruptible: bool | None = None`**
Whether triggered executions can be interrupted (useful for cost optimization with spot/preemptible instances). Overrides the task's interruptible setting.
**`overwrite_cache: bool = False`**
Whether to bypass/overwrite task cache for triggered executions, ensuring fresh computation.
**`queue: str | None = None`**
Specific execution queue for triggered runs, overriding the task's default queue.
#### Metadata Parameters
**`labels: Mapping[str, str] | None = None`**
Key-value labels for organizing and filtering triggers (e.g., team, component, priority).
**`annotations: Mapping[str, str] | None = None`**
Additional metadata, often used by infrastructure tools for compliance, monitoring, or cost tracking.
### Trigger time in predefined triggers
By default, predefined triggers pass the execution time to the parameter `trigger_time` of type `datetime`, if that parameter exists on the task.
If no such parameter exists, the task will still be executed without error.
Optionally, you can customize the parameter name that receives the trigger execution timestamp by setting the `trigger_time_input_key` parameter (in this case the absence of this custom parameter on the task will raise an error at trigger deployment time):
```
@env.task(triggers=flyte.Trigger.daily(trigger_time_input_key="scheduled_at"))
def task_with_custom_trigger_time_input(scheduled_at: datetime) -> str:
    return f"Executed at {scheduled_at}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
## Multiple triggers per task
You can attach multiple triggers to a single task by providing a list of triggers. This allows you to run the same task on different schedules or with different configurations:
```
@env.task(triggers=[
flyte.Trigger.hourly(), # Predefined trigger
flyte.Trigger.daily(), # Another predefined trigger
flyte.Trigger("custom", flyte.Cron("0 */6 * * *")) # Custom trigger every 6 hours
])
def multi_trigger_task(trigger_time: datetime = flyte.TriggerTime) -> str:
    # Different logic based on execution timing
    if trigger_time.hour == 0:  # Daily run at midnight
        return f"Daily comprehensive processing at {trigger_time}"
    else:  # Hourly or custom runs
        return f"Regular processing at {trigger_time.strftime('%H:%M')}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
You can mix and match trigger types, combining predefined triggers with custom triggers that use `flyte.Cron` or `flyte.FixedRate` automations.
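When you combine schedules, check where they overlap. At midnight, all three triggers in the example above are due at once. Assuming each trigger creates its own run, the sketch below counts how many runs each hour of a day would produce. The schedule checks are hand-coded for these three specific expressions rather than parsed from the cron strings.

```python
from collections import Counter

# Hand-coded equivalents of the three schedules above (all fire at minute 0).
schedules = {
    "hourly": lambda hour: True,           # 0 * * * *
    "daily": lambda hour: hour == 0,       # 0 0 * * *
    "custom": lambda hour: hour % 6 == 0,  # 0 */6 * * *
}

runs_per_hour = Counter()
for hour in range(24):
    for name, fires in schedules.items():
        if fires(hour):
            runs_per_hour[hour] += 1

print(runs_per_hour[0], runs_per_hour[6], runs_per_hour[7])  # 3 2 1
print(sum(runs_per_hour.values()))  # 29
```

As a consequence, the check `trigger_time.hour == 0` in `multi_trigger_task` is true for all three midnight runs. It does not isolate the run created by the daily trigger.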
## Deploying a task with triggers
We recommend that you define your triggers in code together with your tasks and deploy them together.
The Union UI displays:
* `Owner` - who last deployed the trigger.
* `Last updated` - who last activated or deactivated the trigger and when. Note: If you deploy a trigger with `auto_activate=True` (the default), this will match the `Owner`.
* `Last Run` - when the last run was created by this trigger.
For development and debugging purposes, you can adjust and deploy individual triggers from the UI.
To deploy a task with its triggers, you can either use the Flyte CLI:
```shell
flyte deploy -p -d env
```
Or in Python:
```python
flyte.deploy(env)
```
Upon deployment, all triggers associated with a given task `T` are automatically switched to apply to the latest version of that task. Triggers on task `T` that were defined elsewhere (for example, in the UI) are deleted unless they are referenced in the task definition of `T`.
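The deploy-time rule can be read as a set reconciliation. The sketch below models it with plain sets of trigger names. The function name and data shapes are hypothetical. The sketch only restates the behavior described above.

```python
def reconcile_triggers(in_code: set[str], existing: set[str]) -> tuple[set[str], set[str]]:
    """Return (triggers moved to the latest task version, triggers deleted) for one task."""
    kept = set(in_code)            # every trigger referenced in code follows the new version
    deleted = existing - in_code   # e.g. triggers that exist only because they were created in the UI
    return kept, deleted


kept, deleted = reconcile_triggers({"hourly", "daily"}, {"daily", "ui_experiment"})
print(sorted(kept), sorted(deleted))  # ['daily', 'hourly'] ['ui_experiment']
```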
## Activating and deactivating triggers
By default, triggers are automatically activated upon deployment (`auto_activate=True`).
Alternatively, you can set `auto_activate=False` to deploy inactive triggers.
An inactive trigger will not create runs until activated.
```
env = flyte.TaskEnvironment(name="my_task_env")
custom_cron_trigger = flyte.Trigger(
"custom_cron",
flyte.Cron("0 0 * * *"),
auto_activate=False # Don't create runs yet
)
@env.task(triggers=custom_cron_trigger)
def custom_task() -> str:
    return "Hello, world!"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
This trigger won't create runs until it is explicitly activated.
You can activate a trigger via the Flyte CLI:
```shell
flyte update trigger custom_cron my_task_env.custom_task --activate --project --domain
```
If you want to stop your trigger from creating new runs, you can deactivate it:
```shell
flyte update trigger custom_cron my_task_env.custom_task --deactivate --project --domain
```
You can also view and manage your deployed triggers in the Union UI.
## Trigger run timing
The timing of the first run created by a trigger depends on the type of trigger used (Cron-based or Fixed-rate) and whether the trigger is active upon deployment.
### Cron-based triggers
For Cron-based triggers, the first run will be created at the next scheduled time according to the cron expression after trigger activation and similarly thereafter.
* `0 0 * * *` If deployed at 17:00 today, the trigger will first fire 7 hours later (0:00 of the following day) and then every day at 0:00 thereafter.
* `*/15 14 * * 1-5` If today is Tuesday at 17:00, the trigger will fire the next day (Wednesday) at 14:00, 14:15, 14:30, and 14:45 and then the same for every subsequent weekday thereafter.
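Both timing examples can be reproduced with a brute-force search that steps forward one minute at a time from the activation time until the expression matches. This is a hypothetical sketch and does not reflect how Flyte schedules runs. The names `next_fire` and `_ok` are illustrative. The matcher is deliberately tiny: it supports only `*`, `*/n`, `a-b`, and plain numbers, and it requires all fields to match.

```python
from datetime import datetime, timedelta


def _ok(spec: str, value: int, low: int) -> bool:
    """Match one cron field (no comma lists in this sketch)."""
    if spec == "*":
        return True
    if spec.startswith("*/"):
        return (value - low) % int(spec[2:]) == 0
    if "-" in spec:
        start, end = map(int, spec.split("-"))
        return start <= value <= end
    return int(spec) == value


def next_fire(expr: str, after: datetime) -> datetime:
    """First whole minute strictly after `after` that matches the cron expression."""
    minute, hour, dom, month, dow = expr.split()
    t = after.replace(second=0, microsecond=0) + timedelta(minutes=1)
    for _ in range(366 * 24 * 60):  # give up after roughly a year of minutes
        cron_dow = (t.weekday() + 1) % 7  # cron numbers weekdays from Sunday=0
        if (_ok(minute, t.minute, 0) and _ok(hour, t.hour, 0) and _ok(dom, t.day, 1)
                and _ok(month, t.month, 1) and _ok(dow, cron_dow, 0)):
            return t
        t += timedelta(minutes=1)
    raise ValueError("no match within a year")


tuesday_5pm = datetime(2025, 10, 28, 17, 0)  # 2025-10-28 is a Tuesday
print(next_fire("0 0 * * *", tuesday_5pm))        # 2025-10-29 00:00:00
print(next_fire("*/15 14 * * 1-5", tuesday_5pm))  # 2025-10-29 14:00:00
```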
### Fixed-rate triggers without `start_time`
If no `start_time` is specified, then the first run will be created after the specified interval from the time of activation. No run will be created immediately upon activation, but the activation time will be used as the reference point for future runs.
#### No `start_time`, auto_activate: True
Let's say you define a fixed rate trigger with automatic activation like this:
```
my_trigger = flyte.Trigger("my_trigger", flyte.FixedRate(60))
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
In this case, the first run will occur 60 minutes after the successful deployment of the trigger.
So, if you deployed this trigger at 13:15, the first run will occur at 14:15 and so on thereafter.
#### No `start_time`, auto_activate: False
On the other hand, let's say you define a fixed rate trigger without automatic activation like this:
```
my_trigger = flyte.Trigger("my_trigger", flyte.FixedRate(60), auto_activate=False)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
Suppose you then activate it about 3 hours later. In this case, the first run will kick off 60 minutes after trigger activation.
If you deployed the trigger at 13:15 and activated it at 16:07, the first run will occur at 17:07.
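Without a `start_time`, both cases use the same arithmetic. The reference point is the activation time, which equals the deployment time when `auto_activate=True`. Run *n* happens *n* intervals after that point. The helper `fixed_rate_runs` below is hypothetical. It makes the two examples above explicit.

```python
from datetime import datetime, timedelta


def fixed_rate_runs(activated_at: datetime, interval_minutes: int, count: int) -> list[datetime]:
    """First `count` run times of a fixed-rate trigger that has no start_time."""
    step = timedelta(minutes=interval_minutes)
    # No run at activation itself; the first run is one full interval later.
    return [activated_at + step * n for n in range(1, count + 1)]


deployed = datetime(2025, 10, 26, 13, 15)  # auto_activate=True: activation == deployment
activated = datetime(2025, 10, 26, 16, 7)  # auto_activate=False: activated manually later
print([t.strftime("%H:%M") for t in fixed_rate_runs(deployed, 60, 2)])   # ['14:15', '15:15']
print([t.strftime("%H:%M") for t in fixed_rate_runs(activated, 60, 1)])  # ['17:07']
```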
### Fixed-rate triggers with `start_time`
If a `start_time` is specified, the timing of the first run depends on whether the trigger is active at `start_time` or not.
#### Fixed-rate with `start_time` while active
If a `start_time` is specified, and the trigger is active at `start_time` then the first run will occur at `start_time` and then at the specified interval thereafter.
For example:
```
my_trigger = flyte.Trigger(
"my_trigger",
# Runs every 60 minutes starting from October 26th, 2025, 10:00am
flyte.FixedRate(60, start_time=datetime(2025, 10, 26, 10, 0, 0)),
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
If you deploy this trigger on October 24th, 2025, the trigger will wait until October 26th 10:00am and will create the first run at exactly 10:00am.
#### Fixed-rate with `start_time` while inactive
If a `start_time` is specified but the trigger is activated after `start_time`, the first run is created at the next point in time that aligns with the trigger interval, using `start_time` as the reference point.
For example:
```
custom_rate_trigger = flyte.Trigger(
"custom_rate",
# Runs every 60 minutes starting from October 26th, 2025, 10:00am
flyte.FixedRate(60, start_time=datetime(2025, 10, 26, 10, 0, 0)),
auto_activate=False
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
If activated later than the `start_time`, say on October 28th 12:35pm for example, the first run will be created at October 28th at 1:00pm.
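With a `start_time`, run times form a fixed grid: `start_time + k * interval`. If the trigger is active at `start_time`, the first run is `start_time` itself. If it is activated later, the first run is the next grid point at or after activation. The function `first_run` below is a hypothetical sketch of that rule.

```python
import math
from datetime import datetime, timedelta


def first_run(start_time: datetime, interval_minutes: int, activated_at: datetime) -> datetime:
    """First run of a fixed-rate trigger anchored at start_time (illustrative model)."""
    if activated_at <= start_time:
        return start_time  # active before start_time: fire exactly at start_time
    step = timedelta(minutes=interval_minutes)
    # Whole intervals needed to reach or pass the activation time.
    k = math.ceil((activated_at - start_time) / step)
    return start_time + k * step


start = datetime(2025, 10, 26, 10, 0)
print(first_run(start, 60, datetime(2025, 10, 24, 9, 0)))    # 2025-10-26 10:00:00
print(first_run(start, 60, datetime(2025, 10, 28, 12, 35)))  # 2025-10-28 13:00:00
```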
## Deleting triggers
If you decide that you don't need a trigger anymore, you can remove the trigger from the task definition and deploy the task again.
Alternatively, you can use the Flyte CLI:
```shell
flyte delete trigger custom_cron my_task_env.custom_task --project --domain
```
## Schedule time zones
### Setting time zone for a Cron schedule
Cron expressions are by default in UTC, but it's possible to specify custom time zones like so:
```
sf_trigger = flyte.Trigger(
"sf_tz",
flyte.Cron(
"0 9 * * *", timezone="America/Los_Angeles"
), # Every day at 9 AM PT
inputs={"start_time": flyte.TriggerTime, "x": 1},
)
nyc_trigger = flyte.Trigger(
"nyc_tz",
flyte.Cron(
"1 12 * * *", timezone="America/New_York"
), # Every day at 12:01 PM ET
inputs={"start_time": flyte.TriggerTime, "x": 1},
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
The above two schedules will fire 1 minute apart, at 9 AM PT and 12:01 PM ET respectively.
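You can confirm the 1-minute gap by converting both wall-clock times to UTC with Python's standard `zoneinfo` module. This assumes the IANA time zone database is available on the machine that runs the snippet.

```python
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

day = (2025, 12, 1)
sf = datetime(*day, 9, 0, tzinfo=ZoneInfo("America/Los_Angeles"))
nyc = datetime(*day, 12, 1, tzinfo=ZoneInfo("America/New_York"))

print(sf.astimezone(timezone.utc).strftime("%H:%M"))   # 17:00
print(nyc.astimezone(timezone.utc).strftime("%H:%M"))  # 17:01
print(nyc - sf)  # 0:01:00
```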
### `flyte.TriggerTime` is always in UTC
The `flyte.TriggerTime` value is always in UTC. For timezone-aware logic, convert as needed:
```
from datetime import datetime, timezone
@env.task(triggers=flyte.Trigger.minutely(trigger_time_input_key="utc_trigger_time", name="timezone_trigger"))
def timezone_task(utc_trigger_time: datetime) -> str:
    # Mark the value as UTC, then convert to the local time zone of the machine running the task
    local_time = utc_trigger_time.replace(tzinfo=timezone.utc).astimezone()
    return f"Task fired at {utc_trigger_time} UTC ({local_time} local)"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-configuration/triggers/triggers.py)
### Daylight Saving Time behavior
When Daylight Saving Time (DST) begins and ends, it can affect when a scheduled execution runs.
For example, consider a schedule set to fire daily at 2:30 AM. On the day DST begins, time jumps from 2:00 AM to 3:00 AM, so 2:30 AM doesn't exist that day. In this case, the trigger will not fire until the next 2:30 AM, which is the following day.
On the day DST ends, the hour from 1:00 AM to 2:00 AM repeats, so 1:30 AM occurs twice. If the schedule were instead set for 1:30 AM, it would run only once, on the first occurrence of 1:30 AM.
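Python's `zoneinfo` module shows both edge cases. The snippet below uses the 2025 US transitions in `America/New_York`: DST began on March 9 and ends on November 2. It assumes the IANA time zone database is installed. Two checks apply:

* A wall-clock time that changes after a round trip through UTC does not exist.
* An ambiguous time has two different UTC offsets, selected by `fold`.

```python
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

ny = ZoneInfo("America/New_York")

# 2:30 AM on the spring-forward day falls inside the skipped hour.
gap = datetime(2025, 3, 9, 2, 30, tzinfo=ny)
round_trip = gap.astimezone(timezone.utc).astimezone(ny)
print(round_trip.strftime("%H:%M"))  # 03:30, because 2:30 never appears on the wall clock

# 1:30 AM on the fall-back day occurs twice; fold=0 is the first occurrence.
first = datetime(2025, 11, 2, 1, 30, tzinfo=ny, fold=0)
second = datetime(2025, 11, 2, 1, 30, tzinfo=ny, fold=1)
print(first.utcoffset().total_seconds() / 3600, second.utcoffset().total_seconds() / 3600)  # -4.0 -5.0
```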
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming ===
# Build tasks
This section covers the essential programming patterns and techniques for developing robust Flyte workflows. Once you understand the basics of task configuration, these guides will help you build sophisticated, production-ready data pipelines and machine learning workflows.
## What you'll learn
The task programming section covers key patterns for building effective Flyte workflows:
**Data handling and types**
- **Build tasks > Files and directories**: Work with large datasets using Flyte's efficient file and directory types that automatically handle data upload, storage, and transfer between tasks.
- **Build tasks > Dataclasses and structures**: Use Python dataclasses and Pydantic models as task inputs and outputs to create well-structured, type-safe workflows.
- **Build tasks > Custom context**: Use custom context to pass metadata through your task execution hierarchy without adding parameters to every task.
**Execution patterns**
- **Build tasks > Fanout**: Scale your workflows by running many tasks in parallel, perfect for processing large datasets or running hyperparameter sweeps.
- **Build tasks > Grouping actions**: Organize related task executions into logical groups for better visualization and management in the UI.
**Development and debugging**
- **Build tasks > Notebooks**: Write and iterate on workflows directly in Jupyter notebooks for interactive development and experimentation.
- **Build tasks > Reports**: Generate custom HTML reports during task execution to display progress, results, and visualizations in the UI.
- **Build tasks > Traces**: Add fine-grained observability to helper functions within your tasks for better debugging and resumption capabilities.
- **Build tasks > Error handling**: Implement robust error recovery strategies, including automatic resource scaling and graceful failure handling.
## When to use these patterns
These programming patterns become essential as your workflows grow in complexity:
- Use **fanout** when you need to process multiple items concurrently or run parameter sweeps.
- Implement **error handling** for production workflows that need to recover from infrastructure failures.
- Apply **grouping** to organize complex workflows with many task executions.
- Leverage **files and directories** when working with large datasets that don't fit in memory.
- Use **traces** to debug non-deterministic operations like API calls or ML inference.
- Create **reports** to monitor long-running workflows and share results with stakeholders.
- Use **custom context** when you need lightweight, cross-cutting metadata to flow through your task hierarchy without becoming part of the task's logical inputs.
Each guide includes practical examples and best practices to help you implement these patterns effectively in your own workflows.
## Subpages
- **Build tasks > Dataclasses and structures**
- **Build tasks > DataFrames**
- **Build tasks > Files and directories**
- **Build tasks > Custom context**
- **Build tasks > Reports**
- **Build tasks > Notebooks**
- **Build tasks > Error handling**
- **Build tasks > Traces**
- **Build tasks > Grouping actions**
- **Build tasks > Fanout**
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming/dataclasses-and-structures ===
# Dataclasses and structures
Dataclasses and Pydantic models are fully supported in Flyte as **materialized data types**:
structured data whose full content is serialized and passed between tasks.
Use them as you normally would, passing them as inputs and outputs of tasks.
Unlike **offloaded types** such as **Build tasks > DataFrames** and **Build tasks > Files and directories**, dataclass and Pydantic model data is fully serialized, stored, and deserialized between tasks.
This makes them ideal for configuration objects, metadata, and smaller structured data where all fields should be serializable.
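To see what "materialized" means in practice, the sketch below round-trips a nested dataclass through JSON using only the standard library. The full content is serialized, and a new, equal object is reconstructed on the other side. JSON is an assumption chosen for illustration, since this page does not specify the wire format Flyte uses.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class InferenceRequest:
    feature_a: float
    feature_b: float


@dataclass
class BatchRequest:
    requests: list[InferenceRequest] = field(default_factory=list)
    batch_id: str = "default"


batch = BatchRequest([InferenceRequest(1.0, 2.0)], batch_id="demo")
payload = json.dumps(asdict(batch))  # every field is written out, nested ones included

data = json.loads(payload)
restored = BatchRequest(
    requests=[InferenceRequest(**r) for r in data["requests"]],
    batch_id=data["batch_id"],
)
print(restored == batch, restored is batch)  # True False
```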
## Example: Combining Dataclasses and Pydantic Models
This example demonstrates how dataclasses and Pydantic models work together as materialized data types, showing nested structures and batch processing patterns:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pydantic",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from dataclasses import dataclass
from typing import List
from pydantic import BaseModel
import flyte
env = flyte.TaskEnvironment(name="ex-mixed-structures")
@dataclass
class InferenceRequest:
    feature_a: float
    feature_b: float
@dataclass
class BatchRequest:
    requests: List[InferenceRequest]
    batch_id: str = "default"
class PredictionSummary(BaseModel):
    predictions: List[float]
    average: float
    count: int
    batch_id: str
@env.task
async def predict_one(request: InferenceRequest) -> float:
    """
    A dummy linear model: prediction = 2 * feature_a + 3 * feature_b + bias(=1.0)
    """
    return 2.0 * request.feature_a + 3.0 * request.feature_b + 1.0
@env.task
async def process_batch(batch: BatchRequest) -> PredictionSummary:
    """
    Processes a batch of inference requests and returns summary statistics.
    """
    # Process all requests concurrently
    tasks = [predict_one(request=req) for req in batch.requests]
    predictions = await asyncio.gather(*tasks)
    # Calculate statistics
    average = sum(predictions) / len(predictions) if predictions else 0.0
    return PredictionSummary(
        predictions=predictions,
        average=average,
        count=len(predictions),
        batch_id=batch.batch_id
    )
@env.task
async def summarize_results(summary: PredictionSummary) -> str:
    """
    Creates a text summary from the prediction results.
    """
    return (
        f"Batch {summary.batch_id}: "
        f"Processed {summary.count} predictions, "
        f"average value: {summary.average:.2f}"
    )
@env.task
async def main() -> str:
    batch = BatchRequest(
        requests=[
            InferenceRequest(feature_a=1.0, feature_b=2.0),
            InferenceRequest(feature_a=3.0, feature_b=4.0),
            InferenceRequest(feature_a=5.0, feature_b=6.0),
        ],
        batch_id="demo_batch_001"
    )
    summary = await process_batch(batch)
    result = await summarize_results(summary)
    return result
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataclasses-and-structures/example.py)
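You can sanity-check the arithmetic without a Flyte backend by running the same logic with plain `asyncio`. In this stand-in, the task decorators are removed and `PredictionSummary` is a dataclass instead of a Pydantic model. It therefore exercises only the computation, not Flyte's serialization or execution.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class InferenceRequest:
    feature_a: float
    feature_b: float


@dataclass
class PredictionSummary:
    predictions: list[float]
    average: float
    count: int
    batch_id: str


async def predict_one(request: InferenceRequest) -> float:
    # Same dummy linear model as the example: 2a + 3b + 1
    return 2.0 * request.feature_a + 3.0 * request.feature_b + 1.0


async def process_batch(requests: list[InferenceRequest], batch_id: str) -> PredictionSummary:
    predictions = await asyncio.gather(*(predict_one(r) for r in requests))
    average = sum(predictions) / len(predictions) if predictions else 0.0
    return PredictionSummary(list(predictions), average, len(predictions), batch_id)


requests = [InferenceRequest(1.0, 2.0), InferenceRequest(3.0, 4.0), InferenceRequest(5.0, 6.0)]
summary = asyncio.run(process_batch(requests, "demo_batch_001"))
print(
    f"Batch {summary.batch_id}: Processed {summary.count} predictions, "
    f"average value: {summary.average:.2f}"
)
# Batch demo_batch_001: Processed 3 predictions, average value: 19.00
```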
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming/dataframes ===
# DataFrames
By default, return values in Python are materialized - meaning the actual data is downloaded and loaded into memory. This applies to simple types like integers, as well as more complex types like DataFrames.
To avoid downloading large datasets into memory, Flyte V2 exposes **Build tasks > DataFrames > `flyte.io.dataframe`**: a thin, uniform wrapper type for DataFrame-style objects that allows you to pass a reference to the data, rather than the fully materialized contents.
The `flyte.io.DataFrame` type provides serialization support for common engines like `pandas`, `polars`, `pyarrow`, `dask`, and others, enabling you to move data between different DataFrame backends.
## Setting up the environment and sample data
For our example we will start by setting up our task environment with the required dependencies and create some sample data.
```
from typing import Annotated
import numpy as np
import pandas as pd
import flyte
import flyte.io
env = flyte.TaskEnvironment(
"dataframe_usage",
image=flyte.Image.from_debian_base().with_pip_packages("pandas", "pyarrow", "numpy"),
resources=flyte.Resources(cpu="1", memory="2Gi"),
)
BASIC_EMPLOYEE_DATA = {
"employee_id": range(1001, 1009),
"name": ["Alice", "Bob", "Charlie", "Diana", "Ethan", "Fiona", "George", "Hannah"],
"department": ["HR", "Engineering", "Engineering", "Marketing", "Finance", "Finance", "HR", "Engineering"],
"hire_date": pd.to_datetime(
["2018-01-15", "2019-03-22", "2020-07-10", "2017-11-01", "2021-06-05", "2018-09-13", "2022-01-07", "2020-12-30"]
),
}
ADDL_EMPLOYEE_DATA = {
"employee_id": range(1001, 1009),
"salary": [55000, 75000, 72000, 50000, 68000, 70000, np.nan, 80000],
"bonus_pct": [0.05, 0.10, 0.07, 0.04, np.nan, 0.08, 0.03, 0.09],
"full_time": [True, True, True, False, True, True, False, True],
"projects": [
["Recruiting", "Onboarding"],
["Platform", "API"],
["API", "Data Pipeline"],
["SEO", "Ads"],
["Budget", "Forecasting"],
["Auditing"],
[],
["Platform", "Security", "Data Pipeline"],
],
}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
## Create a raw dataframe
Now, let's create a task that returns a native Pandas DataFrame:
```
@env.task
async def create_raw_dataframe() -> pd.DataFrame:
    return pd.DataFrame(BASIC_EMPLOYEE_DATA)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
This is the most basic use-case of how to pass DataFrames (of all kinds, not just Pandas).
We simply create the DataFrame as normal, and return it.
Because the task has been declared to return a supported native DataFrame type (in this case `pandas.DataFrame`), Flyte will automatically detect it, serialize it correctly, and upload it at task completion, enabling it to be passed transparently to the next task.
Flyte supports auto-serialization for the following DataFrame types:
* `pandas.DataFrame`
* `pyarrow.Table`
* `dask.dataframe.DataFrame`
* `polars.DataFrame`
* `flyte.io.DataFrame` (see below)
## Create a flyte.io.DataFrame
Alternatively you can also create a `flyte.io.DataFrame` object directly from a native object with the `from_df` method:
```
@env.task
async def create_flyte_dataframe() -> Annotated[flyte.io.DataFrame, "parquet"]:
    pd_df = pd.DataFrame(ADDL_EMPLOYEE_DATA)
    fdf = flyte.io.DataFrame.from_df(pd_df)
    return fdf
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
The `flyte.io.DataFrame` class creates a thin wrapper around objects of any standard DataFrame type. It serves as a generic "any DataFrame type" (a concept that Python itself does not currently offer).
As with native DataFrame types, Flyte will automatically serialize and upload the data at task completion.
The advantage of the unified `flyte.io.DataFrame` wrapper is that you can be explicit about the storage format that makes sense for your use case, by using an `Annotated` type where the second argument encodes the format or other lightweight hints. For example, the `create_flyte_dataframe` task above declares its return type as `Annotated[flyte.io.DataFrame, "parquet"]`, specifying that the DataFrame should be stored as Parquet.
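The string inside `Annotated[...]` is ordinary Python type metadata, so a framework can read it from a function's type hints. The sketch below shows that mechanism with the standard `typing` module. It uses a hypothetical stand-in class, `DataFrameStandIn`, in place of `flyte.io.DataFrame`. This page does not cover how Flyte itself consumes the hint.

```python
from typing import Annotated, get_args, get_origin, get_type_hints


class DataFrameStandIn:
    """Hypothetical placeholder for flyte.io.DataFrame, to keep the sketch dependency-free."""


def create_frame() -> Annotated[DataFrameStandIn, "parquet"]:
    return DataFrameStandIn()


# include_extras=True keeps the Annotated metadata instead of stripping it.
ret = get_type_hints(create_frame, include_extras=True)["return"]
assert get_origin(ret) is Annotated
base, *hints = get_args(ret)
print(base.__name__, hints)  # DataFrameStandIn ['parquet']
```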
## Automatically convert between types
You can leverage Flyte to automatically download and convert the dataframe between types when needed:
```
@env.task
async def join_data(raw_dataframe: pd.DataFrame, flyte_dataframe: pd.DataFrame) -> flyte.io.DataFrame:
    joined_df = raw_dataframe.merge(flyte_dataframe, on="employee_id", how="inner")
    return flyte.io.DataFrame.from_df(joined_df)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
This task takes two dataframes as input. We'll pass one raw Pandas dataframe, and one `flyte.io.DataFrame`.
Flyte automatically converts the `flyte.io.DataFrame` to a Pandas DataFrame (since we declared that as the input type) before passing it to the task.
The actual download and conversion happens only when we access the data, in this case, when we do the merge.
## Downloading DataFrames
When a task receives a `flyte.io.DataFrame`, you can request a concrete backend representation. For example, to download as a pandas DataFrame:
```
@env.task
async def download_data(joined_df: flyte.io.DataFrame):
    downloaded = await joined_df.open(pd.DataFrame).all()
    print("Downloaded Data:\n", downloaded)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
The `open()` call delegates to the DataFrame handler for the stored format and converts to the requested in-memory type.
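The delegation that `open()` performs can be pictured as a lookup keyed by stored format and requested type. The registry below is a hypothetical, dependency-free model of that idea, in which CSV text is decoded into either rows or columns. It is not Flyte's handler API. The names `DECODERS`, `register`, and `open_as` are illustrative.

```python
import csv
import io
from typing import Callable

# (stored format, requested type) -> decoder function
DECODERS: dict[tuple[str, type], Callable[[str], object]] = {}


def register(fmt: str, target: type):
    """Decorator that records a decoder for one (format, type) pair."""
    def wrap(fn: Callable[[str], object]) -> Callable[[str], object]:
        DECODERS[(fmt, target)] = fn
        return fn
    return wrap


@register("csv", list)
def csv_to_rows(raw: str) -> list:
    return list(csv.DictReader(io.StringIO(raw)))


@register("csv", dict)
def csv_to_columns(raw: str) -> dict:
    columns: dict[str, list[str]] = {}
    for row in csv.DictReader(io.StringIO(raw)):
        for key, value in row.items():
            columns.setdefault(key, []).append(value)
    return columns


def open_as(raw: str, fmt: str, target: type) -> object:
    decoder = DECODERS.get((fmt, target))
    if decoder is None:
        raise TypeError(f"no handler for {fmt!r} -> {target.__name__}")
    return decoder(raw)


stored = "employee_id,name\n1001,Alice\n1002,Bob\n"
print(open_as(stored, "csv", dict))  # {'employee_id': ['1001', '1002'], 'name': ['Alice', 'Bob']}
```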
## Run the example
Finally, we can define a `main` function to run the tasks defined above and a `__main__` block to execute the workflow:
```
@env.task
async def main():
    raw_df = await create_raw_dataframe()
    flyte_df = await create_flyte_dataframe()
    joined_df = await join_data(raw_df, flyte_df)
    await download_data(joined_df)
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/dataframes/dataframes.py)
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming/files-and-directories ===
# Files and directories
Flyte provides the `flyte.io.File` and `flyte.io.Dir` types to represent files and directories, respectively.
Together with DataFrames, they constitute the *offloaded data types*: unlike dataclasses and other structured types, these pass references rather than full data content.
A variable of an offloaded type does not contain its actual data, but rather a reference to the data.
The actual data is stored in the internal blob store of your Union/Flyte instance.
When a variable of an offloaded type is first created, its data is uploaded to the blob store.
It can then be passed from task to task as a reference.
The actual data is only downloaded from the blob store when the task needs to access it, for example, when the task calls `open()` on a `File` or `Dir` object.
This allows Flyte to efficiently handle large files and directories without needing to transfer the data unnecessarily.
Even very large data objects like video files and DNA datasets can be passed efficiently between tasks.
The `File` and `Dir` classes provide both `sync` and `async` methods to interact with the data.
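To make the reference-passing idea concrete, here is a minimal model in plain Python. It is a sketch under stated assumptions: `BlobStore` is an in-memory dict standing in for the real blob store, and `OffloadedFile` is a hypothetical class, not `flyte.io.File`. Creating the object uploads the bytes once, passing it between "tasks" copies only the URI, and the bytes are fetched only when `read()` is called.

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class BlobStore:
    """In-memory stand-in for a blob store; counts downloads for illustration."""
    objects: dict[str, bytes] = field(default_factory=dict)
    downloads: int = 0

    def put(self, data: bytes) -> str:
        uri = f"blob://{uuid.uuid4().hex}"
        self.objects[uri] = data
        return uri

    def get(self, uri: str) -> bytes:
        self.downloads += 1
        return self.objects[uri]


@dataclass(frozen=True)
class OffloadedFile:
    """Hypothetical reference type: holds only a URI, never the bytes."""
    uri: str

    @classmethod
    def upload(cls, store: BlobStore, data: bytes) -> "OffloadedFile":
        # Upload happens once, when the reference is created.
        return cls(uri=store.put(data))

    def read(self, store: BlobStore) -> bytes:
        # Download happens only when the data is actually accessed.
        return store.get(self.uri)


def pass_through(f: OffloadedFile) -> OffloadedFile:
    # A "task" that forwards the reference without touching the data.
    return f


store = BlobStore()
ref = OffloadedFile.upload(store, b"hello world")
ref = pass_through(pass_through(ref))
print(store.downloads)  # 0: nothing downloaded yet
print(ref.read(store))  # b'hello world'
print(store.downloads)  # 1
```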
## Example usage
The examples below show the basic use cases of uploading files and directories created locally, and using them as inputs to a task.
```
import asyncio
import tempfile
from pathlib import Path
import flyte
from flyte.io import Dir, File
env = flyte.TaskEnvironment(name="files-and-folders")
@env.task
async def write_file(name: str) -> File:
    # Create a file and write some content to it
    with open("test.txt", "w") as f:
        f.write(f"hello world {name}")
    # Upload the file using flyte
    uploaded_file_obj = await File.from_local("test.txt")
    return uploaded_file_obj
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/file_and_dir.py)
The upload happens when `File.from_local` is called.
Because the upload would otherwise block execution, `File.from_local` is implemented as an `async` function.
The Flyte SDK frequently uses this class constructor pattern, so you will see it with other types as well.
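The async class-constructor pattern itself can be sketched with the standard library alone. In this hypothetical `RemoteFile` class (not part of the Flyte SDK), `from_local` is an `async` classmethod that runs the blocking copy in a worker thread via `asyncio.to_thread`, so the event loop stays free while the "upload" runs. A second local temporary directory stands in for remote storage.

```python
import asyncio
import shutil
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class RemoteFile:
    """Hypothetical reference to an uploaded file."""
    path: str

    @classmethod
    async def from_local(cls, local_path: str, dest_dir: str) -> "RemoteFile":
        # The copy blocks, so run it off the event loop.
        dest = Path(dest_dir) / Path(local_path).name
        await asyncio.to_thread(shutil.copyfile, local_path, dest)
        return cls(path=str(dest))


async def demo() -> list[RemoteFile]:
    src_dir = tempfile.mkdtemp()
    dest_dir = tempfile.mkdtemp()
    paths = []
    for name in ["Alice", "Bob", "Eve"]:
        p = Path(src_dir) / f"{name}.txt"
        p.write_text(f"hello world {name}")
        paths.append(str(p))
    # Because each upload is awaitable, several can proceed concurrently.
    return await asyncio.gather(*(RemoteFile.from_local(p, dest_dir) for p in paths))


uploaded = asyncio.run(demo())
for f in uploaded:
    print(f.path, Path(f.path).read_text())
```

A regular `__init__` cannot be awaited, which is why an async factory classmethod is the usual way to build an object whose construction involves I/O.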
This is a slightly more complicated task that calls the task above to produce `File` objects.
These are assembled into a directory and the `Dir` object is returned, also via invoking `from_local`.
```
@env.task
async def write_and_check_files() -> Dir:
    coros = []
    for name in ["Alice", "Bob", "Eve"]:
        coros.append(write_file(name=name))
    vals = await asyncio.gather(*coros)
    temp_dir = tempfile.mkdtemp()
    for file in vals:
        async with file.open("rb") as fh:
            contents = await fh.read()
        # Convert bytes to string
        contents_str = contents.decode('utf-8') if isinstance(contents, bytes) else str(contents)
        print(f"File {file.path} contents: {contents_str}")
        new_file = Path(temp_dir) / file.name
        with open(new_file, "w") as out:  # noqa: ASYNC230
            out.write(contents_str)
    print(f"Files written to {temp_dir}")
    # walk the directory and ls
    for path in Path(temp_dir).iterdir():
        print(f"File: {path.name}")
    my_dir = await Dir.from_local(temp_dir)
    return my_dir
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/file_and_dir.py)
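Finally, these tasks show how to use an offloaded type as an input.
The objects provide helper methods that do what you might expect, such as `walk()` on a `Dir`, which asynchronously iterates over the files it contains, and `open()` on a `File`, which returns a file handle you can read from.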
```
@env.task
async def check_dir(my_dir: Dir):
    print(f"Dir {my_dir.path} contents:")
    async for file in my_dir.walk():
        print(f"File: {file.name}")
        async with file.open("rb") as fh:
            contents = await fh.read()
            # Convert bytes to string
            contents_str = contents.decode('utf-8') if isinstance(contents, bytes) else str(contents)
            print(f"Contents: {contents_str}")
@env.task
async def create_and_check_dir():
    my_dir = await write_and_check_files()
    await check_dir(my_dir=my_dir)
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(create_and_check_dir)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/files-and-directories/file_and_dir.py)
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming/custom-context ===
# Custom context
Custom context provides a mechanism for implicitly passing configuration and metadata through your entire task execution hierarchy without adding parameters to every task. It is ideal for cross-cutting concerns such as tracing, environment metadata, or experiment identifiers.
Think of custom context as **execution-scoped metadata** that automatically flows from parent to child tasks.
## Overview
Custom context is an implicit key/value configuration map that is automatically available to tasks during execution. It is stored in the blob store of your Union/Flyte instance together with the task's inputs, making it available across tasks without needing to pass it explicitly.
You can access it in a Flyte task via:
```python
flyte.ctx().custom_context
```
Custom context is fundamentally different from standard task inputs. Task inputs are explicit, strongly typed parameters that you declare as part of a task's signature. They directly influence the task's computation and therefore participate in Flyte's caching and reproducibility guarantees.
Custom context, on the other hand, is implicit metadata. It consists only of string key/value pairs, is not part of the task signature, and does not affect task caching. Because it is injected by the Flyte runtime rather than passed as a formal input, it should be used only for environmental or contextual information, not for data that changes the logical output of a task.
## When to use it and when not to
Custom context is ideal when you need metadata, not domain data, to flow through your tasks.
Good use cases:
- Tracing IDs, span IDs
- Experiment or run metadata
- Environment region, cluster ID
- Logging correlation keys
- Feature flags
- Session IDs for 3rd-party APIs (e.g., an LLM session)
Avoid using for:
- Business/domain data
- Inputs that change task outputs
- Anything affecting caching or reproducibility
- Large blobs of data (keep it small)
It is the cleanest mechanism when you need something available everywhere, but not logically an input to the computation.
## Setting custom context
There are two ways to set custom context for a Flyte run:
1. Set it once for the entire run when you launch (`with_runcontext`). This establishes the base context for the execution.
2. Set or override it inside task code using the `flyte.custom_context(...)` context manager. This changes the active context for that task block and any nested tasks called from it.
Both are legitimate and complementary. The important behavioral rules to understand are:
- `with_runcontext(...)` sets the run-level base. Values provided here are available everywhere unless overridden later. Use this for metadata that should apply to most or all tasks in the run (experiment name, top-level trace id, run id, etc.).
- `flyte.custom_context(...)` is used inside task code to set or override values for that scope. It does affect nested tasks invoked while that context is active. In practice this means you can override run-level entries, add new keys for downstream tasks, or both.
- Merging & precedence: contexts are merged; when the same key appears in multiple places the most recent/innermost value wins (i.e., values set by `flyte.custom_context(...)` override the run-level values from `with_runcontext(...)` for the duration of that block).
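The merging and precedence rules above behave like a stack of dictionaries. The sketch below models them with Python's `contextvars`; it is an illustrative assumption about the described semantics, not Flyte's implementation, and `custom_context`/`current_context` here are hypothetical stand-ins for `flyte.custom_context(...)` and `flyte.ctx().custom_context`. A `contextvars` value is visible to code called inside the block (including asyncio tasks created there), which mirrors how overrides reach nested tasks.

```python
import contextvars
from contextlib import contextmanager
from typing import Iterator

# Hypothetical stand-in for the active custom context (empty before the run starts).
_ctx = contextvars.ContextVar("custom_context", default={})


def current_context() -> dict:
    # Return a copy so callers cannot mutate the active context.
    return dict(_ctx.get())


@contextmanager
def custom_context(**overrides: str) -> Iterator[None]:
    # Merge: the innermost values win for the duration of the block.
    token = _ctx.set({**_ctx.get(), **overrides})
    try:
        yield
    finally:
        _ctx.reset(token)  # restore the outer context on exit


def downstream() -> str:
    return current_context()["trace_id"]


with custom_context(trace_id="root-abc", experiment="v1"):  # plays the run-level base
    before = current_context()
    with custom_context(trace_id="child-override", run_group="g-7"):
        inner = current_context()
        seen = downstream()
    after = current_context()

print(before)  # {'trace_id': 'root-abc', 'experiment': 'v1'}
print(inner)   # {'trace_id': 'child-override', 'experiment': 'v1', 'run_group': 'g-7'}
print(after)   # {'trace_id': 'root-abc', 'experiment': 'v1'}
```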
### Run-level context
Set base metadata once when starting the run:
```
import flyte
env = flyte.TaskEnvironment("custom-context-example")
@env.task
async def leaf_task() -> str:
    # Reads run-level context
    print("leaf sees:", flyte.ctx().custom_context)
    return flyte.ctx().custom_context.get("trace_id")
@env.task
async def root() -> str:
    return await leaf_task()
if __name__ == "__main__":
    flyte.init_from_config()
    # Base context for the entire run
    flyte.with_runcontext(custom_context={"trace_id": "root-abc", "experiment": "v1"}).run(root)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/custom-context/run_context.py)
Output (every task sees the base keys unless overridden):
```bash
leaf sees: {"trace_id": "root-abc", "experiment": "v1"}
```
### Overriding inside a task (local override that affects nested tasks)
Use `flyte.custom_context(...)` inside a task to override or add keys for downstream calls:
```
@env.task
async def downstream() -> str:
    print("downstream sees:", flyte.ctx().custom_context)
    return flyte.ctx().custom_context.get("trace_id")
@env.task
async def parent() -> str:
    print("parent initial:", flyte.ctx().custom_context)
    # Override the trace_id for the nested call(s)
    with flyte.custom_context(trace_id="child-override"):
        val = await downstream()  # downstream sees trace_id="child-override"
    # After the context block, run-level values are back
    print("parent after:", flyte.ctx().custom_context)
    return val
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/custom-context/override_context.py)
If the run was started with `{"trace_id": "root-abc"}`, this prints:
```bash
parent initial: {"trace_id": "root-abc"}
downstream sees: {"trace_id": "child-override"}
parent after: {"trace_id": "root-abc"}
```
Note that the override affected the nested downstream task because it was invoked while the `flyte.custom_context` block was active.
### Adding new keys for nested tasks
You can add keys (not just override):
```python
with flyte.custom_context(experiment="exp-blue", run_group="g-7"):
    await some_task()  # some_task sees both base keys + the new keys
```
## Accessing custom context
Always read it through the Flyte runtime:
```python
ctx = flyte.ctx().custom_context
value = ctx.get("key")
```
You can access the custom context using either `flyte.ctx().custom_context` or the shorthand `flyte.get_custom_context()`, which returns the same dictionary of key/value pairs.
Values are always strings, so parse as needed:
```python
timeout = int(ctx["timeout_seconds"])
```
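Because every value is a string, it can help to centralize parsing and defaults in one place. The helpers below are a hypothetical sketch (not part of the Flyte SDK). They accept any mapping, such as the dictionary returned by `flyte.ctx().custom_context`; a plain dict is used here so the example runs without Flyte, and the key names are illustrative.

```python
from collections.abc import Mapping


def get_int(ctx: Mapping[str, str], key: str, default: int) -> int:
    """Parse an integer value, falling back to `default` if the key is absent."""
    raw = ctx.get(key)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"custom context key {key!r} is not an integer: {raw!r}") from None


def get_bool(ctx: Mapping[str, str], key: str, default: bool = False) -> bool:
    """Interpret common truthy/falsy strings, e.g. for feature flags."""
    raw = ctx.get(key)
    if raw is None:
        return default
    value = raw.strip().lower()
    if value in {"1", "true", "yes", "on"}:
        return True
    if value in {"0", "false", "no", "off"}:
        return False
    raise ValueError(f"custom context key {key!r} is not a boolean: {raw!r}")


ctx = {"timeout_seconds": "30", "use_new_ranker": "true"}
print(get_int(ctx, "timeout_seconds", default=60))  # 30
print(get_int(ctx, "max_retries", default=3))       # 3
print(get_bool(ctx, "use_new_ranker"))              # True
```

Failing loudly on malformed values is a deliberate choice here: a misspelled flag value is easier to debug as an error than as a silent default.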
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming/reports ===
# Reports
The reports feature allows you to display and update custom output in the UI during task execution.
First, you set the `report=True` flag in the task decorator. This enables the reporting feature for that task.
Within a task with reporting enabled, a `flyte.report.Report` object is created automatically.
A `Report` object contains one or more tabs, each of which contains HTML.
You can write HTML to an existing tab and create new tabs to organize your content.
Initially, the `Report` object has one tab (the default tab) with no content.
To write content:
- `flyte.report.log()` appends HTML content directly to the default tab.
- `flyte.report.replace()` replaces the content of the default tab with new HTML.
To get or create a new tab:
- `flyte.report.get_tab()` takes a unique name for the tab and returns the existing tab with that name, or creates a new one if it doesn't exist.
  It returns a `flyte.report._report.Tab` object.
You can `log()` or `replace()` HTML on the `Tab` object just as you can directly on the `Report` object.
Finally, you send the report to the Flyte server and make it visible in the UI:
- `flyte.report.flush()` dispatches the report.
**It is important to call this method to ensure that the data is sent**.
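The relationship between the report, its tabs, and `log()`/`replace()` can be modeled in a few lines of plain Python. This is a hypothetical sketch of the semantics described above, not the Flyte SDK: `Report`, `Tab`, the `"main"` name for the default tab, and the `flushed` list are illustrative, and "flushing" simply snapshots the rendered HTML instead of sending it to a server.

```python
from dataclasses import dataclass, field


@dataclass
class Tab:
    name: str
    content: list[str] = field(default_factory=list)

    def log(self, html: str) -> None:
        # Append HTML to the tab.
        self.content.append(html)

    def replace(self, html: str) -> None:
        # Discard previous content and start over.
        self.content = [html]


@dataclass
class Report:
    tabs: dict[str, Tab] = field(default_factory=lambda: {"main": Tab("main")})
    flushed: list[dict[str, str]] = field(default_factory=list)

    def get_tab(self, name: str) -> Tab:
        # Return the existing tab, or create it on first use.
        return self.tabs.setdefault(name, Tab(name))

    def flush(self) -> None:
        # Snapshot what the UI would display at this point.
        self.flushed.append({n: "".join(t.content) for n, t in self.tabs.items()})


report = Report()
default = report.get_tab("main")
default.log("<p>step 1</p>")
default.log("<p>step 2</p>")
report.get_tab("Tab 2").replace("<h2>Summary</h2>")
report.flush()
print(report.flushed[-1])  # {'main': '<p>step 1</p><p>step 2</p>', 'Tab 2': '<h2>Summary</h2>'}
```

Nothing reaches the (simulated) UI until `flush()` is called, which is the behavior the bold sentence above warns about.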
## A simple example
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# ]
# main = "main"
# params = ""
# ///
import flyte
import flyte.report
env = flyte.TaskEnvironment(name="reports_example")
@env.task(report=True)
async def task1():
    await flyte.report.replace.aio("...")  # HTML content omitted here; see the source file
    await flyte.report.flush.aio()
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(task1)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/simple.py)
Here we define a task `task1` that logs some HTML content to the default tab and creates a new tab named "Tab 2" where it logs additional HTML content.
The `flush` method is called to send the report to the backend.
## A more complex example
Here is another example.
We import the necessary modules, set up the task environment, define the main task with reporting enabled and define the data generation function:
```
import json
import random
import flyte
import flyte.report
env = flyte.TaskEnvironment(
    name="globe_visualization",
)
@env.task(report=True)
async def generate_globe_visualization():
    await flyte.report.replace.aio(get_html_content())
    await flyte.report.flush.aio()
def generate_globe_data():
    """Generate sample data points for the globe"""
    cities = [
        {"city": "New York", "country": "USA", "lat": 40.7128, "lng": -74.0060},
        {"city": "London", "country": "UK", "lat": 51.5074, "lng": -0.1278},
        {"city": "Tokyo", "country": "Japan", "lat": 35.6762, "lng": 139.6503},
        {"city": "Sydney", "country": "Australia", "lat": -33.8688, "lng": 151.2093},
        {"city": "Paris", "country": "France", "lat": 48.8566, "lng": 2.3522},
        {"city": "São Paulo", "country": "Brazil", "lat": -23.5505, "lng": -46.6333},
        {"city": "Mumbai", "country": "India", "lat": 19.0760, "lng": 72.8777},
        {"city": "Cairo", "country": "Egypt", "lat": 30.0444, "lng": 31.2357},
        {"city": "Moscow", "country": "Russia", "lat": 55.7558, "lng": 37.6176},
        {"city": "Beijing", "country": "China", "lat": 39.9042, "lng": 116.4074},
        {"city": "Lagos", "country": "Nigeria", "lat": 6.5244, "lng": 3.3792},
        {"city": "Mexico City", "country": "Mexico", "lat": 19.4326, "lng": -99.1332},
        {"city": "Bangkok", "country": "Thailand", "lat": 13.7563, "lng": 100.5018},
        {"city": "Istanbul", "country": "Turkey", "lat": 41.0082, "lng": 28.9784},
        {"city": "Buenos Aires", "country": "Argentina", "lat": -34.6118, "lng": -58.3960},
        {"city": "Cape Town", "country": "South Africa", "lat": -33.9249, "lng": 18.4241},
        {"city": "Dubai", "country": "UAE", "lat": 25.2048, "lng": 55.2708},
        {"city": "Singapore", "country": "Singapore", "lat": 1.3521, "lng": 103.8198},
        {"city": "Stockholm", "country": "Sweden", "lat": 59.3293, "lng": 18.0686},
        {"city": "Vancouver", "country": "Canada", "lat": 49.2827, "lng": -123.1207},
    ]
    categories = ["high", "medium", "low", "special"]
    data_points = []
    for city in cities:
        data_point = {**city, "value": random.randint(10, 100), "category": random.choice(categories)}
        data_points.append(data_point)
    return data_points
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/globe_visualization.py)
We then define the HTML content for the report:
```python
def get_html_content():
    data_points = generate_globe_data()
    html_content = f"""
    ...
    """
    return html_content
```
(We exclude it here due to length. You can find it in the [source file](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/globe_visualization.py)).
Finally, we run the workflow:
```
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(generate_globe_visualization)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/globe_visualization.py)
When the workflow runs, the report will be visible in the UI:

## Streaming example
Above we demonstrated reports that are sent to the UI once, at the end of the task execution.
But you can also stream updates to the report during task execution and see the display update in real time.
You do this by calling `flyte.report.flush()` (or specifying `do_flush=True` in `flyte.report.log()`) periodically during the task execution, instead of just at the end.
> [!NOTE]
> In the above examples we explicitly call `flyte.report.flush()` to send the report to the UI.
> In fact, this is optional since flush will be called automatically at the end of the task execution.
> For streaming reports, on the other hand, calling `flush()` periodically (or specifying `do_flush=True`
> in `flyte.report.log()`) is necessary to display the updates.
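If updates arrive faster than you want to push them to the UI, one option is to throttle flushes. The class below is a hypothetical helper, not part of `flyte.report`: it takes any async `flush` callable (in a real task this could wrap `flyte.report.flush.aio`) and an injectable clock, so the example runs deterministically without Flyte.

```python
import asyncio
import time
from typing import Awaitable, Callable


class ThrottledFlusher:
    """Call `flush` at most once per `interval` seconds, plus once at the end."""

    def __init__(self, flush: Callable[[], Awaitable[None]], interval: float,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self._flush = flush
        self._interval = interval
        self._clock = clock
        self._last = float("-inf")

    async def maybe_flush(self) -> bool:
        now = self._clock()
        if now - self._last >= self._interval:
            self._last = now
            await self._flush()
            return True
        return False

    async def final_flush(self) -> None:
        # Always send the final state, even if we flushed recently.
        self._last = self._clock()
        await self._flush()


async def demo() -> int:
    flush_count = 0

    async def fake_flush() -> None:
        # Stand-in for a real flush call.
        nonlocal flush_count
        flush_count += 1

    fake_time = iter([0.0, 0.2, 0.4, 1.1, 1.3, 2.5, 2.6])
    flusher = ThrottledFlusher(fake_flush, interval=1.0, clock=lambda: next(fake_time))
    for _ in range(6):           # six updates at the first six fake timestamps
        await flusher.maybe_flush()
    await flusher.final_flush()  # uses the seventh timestamp
    return flush_count


print(asyncio.run(demo()))  # 4
```

With a one-second interval, the six updates produce three flushes (at 0.0, 1.1, and 2.5), and the final flush makes four.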
First we import the necessary modules and set up the task environment:
```
import asyncio
import json
import math
import random
import time
from datetime import datetime
from typing import List
import flyte
import flyte.report
env = flyte.TaskEnvironment(name="streaming_reports")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/streaming_reports.py)
Next we define the HTML content for the report:
```python
DATA_PROCESSING_DASHBOARD_HTML = """
...
"""
```
(We exclude it here due to length. You can find it in the [source file](https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/streaming_reports.py)).
Finally, we define the task that renders the report (`data_processing_dashboard`), the driver task of the workflow (`main`), and the run logic:
```
@env.task(report=True)
async def data_processing_dashboard(total_records: int = 50000) -> str:
    """
    Simulates a data processing pipeline with real-time progress visualization.
    Updates every second for approximately 1 minute.
    """
    await flyte.report.log.aio(DATA_PROCESSING_DASHBOARD_HTML, do_flush=True)
    # Simulate data processing
    processed = 0
    errors = 0
    batch_sizes = [800, 850, 900, 950, 1000, 1050, 1100]  # Variable processing rates
    start_time = time.time()
    # The HTML/JavaScript snippets logged in this loop are omitted from this excerpt; see the source file.
    while processed < total_records:
        # Simulate variable processing speed
        batch_size = random.choice(batch_sizes)
        # Add some processing delays occasionally
        if random.random() < 0.1:  # 10% chance of slower batch
            batch_size = int(batch_size * 0.6)
            await flyte.report.log.aio("""
            """, do_flush=True)
        elif random.random() < 0.05:  # 5% chance of error
            errors += random.randint(1, 5)
            await flyte.report.log.aio("""
            """, do_flush=True)
        else:
            await flyte.report.log.aio(f"""
            """, do_flush=True)
        processed = min(processed + batch_size, total_records)
        current_time = time.time()
        elapsed = current_time - start_time
        rate = int(batch_size) if elapsed < 1 else int(processed / elapsed)
        success_rate = ((processed - errors) / processed) * 100 if processed > 0 else 100
        # Update dashboard
        await flyte.report.log.aio(f"""
        """, do_flush=True)
        print(f"Processed {processed:,} records, Errors: {errors}, Rate: {rate:,}"
              f" records/sec, Success Rate: {success_rate:.2f}%", flush=True)
        await asyncio.sleep(1)  # Update every second
        if processed >= total_records:
            break
    # Final completion message
    total_time = time.time() - start_time
    avg_rate = int(total_records / total_time)
    await flyte.report.log.aio(f"""
    Processing Complete!
    Total Records: {total_records:,}
    Processing Time: {total_time:.1f} seconds
    Average Rate: {avg_rate:,} records/second
    Success Rate: {success_rate:.2f}%
    Errors Handled: {errors}
    """, do_flush=True)
    print(f"Data processing completed: {processed:,} records processed with {errors} errors.", flush=True)
    return f"Processed {total_records:,} records successfully"
@env.task
async def main():
    """
    Main task that runs the data processing dashboard.
    """
    await data_processing_dashboard(total_records=50000)
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/reports/streaming_reports.py)
The key to the live updates is the `while` loop, which appends HTML and JavaScript to the report and flushes after each append. The JavaScript runs as it is appended to the document and updates the dashboard in place.
When the workflow runs, you can see the report updating in real-time in the UI:

=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming/notebooks ===
# Notebooks
Flyte is designed to work seamlessly with Jupyter notebooks, allowing you to write and execute workflows directly within a notebook environment.
## Iterating on and running a workflow
Download the following notebook file and open it in your favorite Jupyter environment: [interactive.ipynb](../../_static/public/interactive.ipynb)
In this example we have a simple workflow defined in our notebook.
You can iterate on the code in the notebook while running each cell in turn.
Note that the `flyte.init()` call at the top of the notebook looks like this:
```python
flyte.init(
    endpoint="https://union.example.com",
    org="example_org",
    project="example_project",
    domain="development",
)
```
You will have to adjust it to match your Union server endpoint, organization, project, and domain.
## Accessing runs and downloading logs
Similarly, you can download the following notebook file and open it in your favorite Jupyter environment: [remote.ipynb](../../_static/public/remote.ipynb)
In this example we use the `flyte.remote` package to list existing runs, access them, and download their details and logs.
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming/error-handling ===
# Error handling
One of the key features of Flyte 2 is the ability to recover from user-level errors in a workflow execution.
This includes out-of-memory errors and other exceptions.
In a distributed system with heterogeneous compute, certain types of errors are expected and even, in a sense, acceptable.
Flyte 2 recognizes this and allows you to handle them gracefully as part of your workflow logic.
This ability is a direct result of the fact that workflows are now written in regular Python,
giving you all the power and flexibility of Python error handling.
Let's look at an example:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# ]
# main = "main"
# params = ""
# ///
import asyncio
import flyte
import flyte.errors
env = flyte.TaskEnvironment(name="fail", resources=flyte.Resources(cpu=1, memory="250Mi"))
@env.task
async def oomer(x: int):
    large_list = [0] * 100000000
    print(len(large_list))
@env.task
async def always_succeeds() -> int:
    await asyncio.sleep(1)
    return 42
@env.task
async def main() -> int:
    try:
        await oomer(2)
    except flyte.errors.OOMError as e:
        print(f"Failed with oom trying with more resources: {e}, of type {type(e)}, {e.code}")
        try:
            await oomer.override(resources=flyte.Resources(cpu=1, memory="1Gi"))(5)
        except flyte.errors.OOMError as e:
            print(f"Failed with OOM Again giving up: {e}, of type {type(e)}, {e.code}")
            raise e
    finally:
        await always_succeeds()
    return await always_succeeds()
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(main)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/error-handling/error_handling.py)
In this code, we do the following:
* Import the necessary modules
* Set up the task environment. Note that we define our task environment with a resource allocation of 1 CPU and 250 MiB of memory.
* Define two tasks: one that will intentionally cause an out-of-memory (OOM) error, and another that will always succeed.
* Define the main task (the top level workflow task) that will handle the failure recovery logic.
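The outer `try...except` block attempts to run the `oomer` task, which allocates a list of 100 million elements and is therefore likely to exceed the 250 MiB memory limit.
If the error occurs, the `except` clause catches the `flyte.errors.OOMError` and uses `oomer.override(...)` to run the `oomer` task again with more memory (1 GiB). If that attempt also fails, the error is re-raised; in either case, the `finally` clause runs `always_succeeds`.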
This type of dynamic error handling allows you to gracefully recover from user-level errors in your workflows.
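The same recover-and-retry idea generalizes to a loop over increasing resource limits. The sketch below is hypothetical and runs without Flyte: `OutOfMemory` stands in for `flyte.errors.OOMError`, and `run_with_memory` stands in for awaiting `oomer.override(resources=flyte.Resources(memory=...))(...)`. In real task code, the loop body would await the overridden task instead.

```python
import asyncio


class OutOfMemory(Exception):
    """Stand-in for flyte.errors.OOMError."""


async def run_with_memory(memory_mib: int, needed_mib: int) -> str:
    # Simulated task: fails if its allocation exceeds the memory limit.
    await asyncio.sleep(0)
    if needed_mib > memory_mib:
        raise OutOfMemory(f"needed {needed_mib} MiB, limit {memory_mib} MiB")
    return f"succeeded with {memory_mib} MiB"


async def with_escalating_memory(needed_mib: int, limits_mib: list[int]) -> str:
    last_error: OutOfMemory | None = None
    for limit in limits_mib:
        try:
            return await run_with_memory(limit, needed_mib)
        except OutOfMemory as e:
            print(f"OOM at {limit} MiB, retrying with more memory: {e}")
            last_error = e
    # Every limit failed: surface the last error to the caller.
    raise RuntimeError("all memory limits exhausted") from last_error


print(asyncio.run(with_escalating_memory(800, [250, 1024, 4096])))  # succeeded with 1024 MiB
```

Bounding the retries with an explicit list of limits keeps a genuinely runaway task from escalating indefinitely.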
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming/traces ===
# Traces
The `@flyte.trace` decorator provides fine-grained observability and resumption capabilities for functions called within your Flyte workflows.
Traces are used on **helper functions** that tasks call to perform specific operations like API calls, data processing, or computations.
Traces are particularly useful for functions with non-deterministic behavior, allowing you to track execution details and resume from failures.
## What are traced functions for?
At the top level, Flyte workflows are composed of **tasks**. But it is also common practice to break down complex task logic into smaller, reusable functions by defining helper functions that tasks call to perform specific operations.
Any helper functions defined or imported into the same file as a task definition are automatically uploaded to the Flyte environment alongside the task when it is deployed.
At the task level, observability and resumption of failed executions is provided by caching, but what if you want these capabilities at a more granular level, for the individual operations that tasks perform?
This is where **traced functions** come in. By decorating helper functions with `@flyte.trace`, you enable:
- **Detailed observability**: Track execution time, inputs/outputs, and errors for each function call.
- **Fine-grained resumption**: If a workflow fails, resume from the last successful traced function instead of re-running the entire task.
Each traced function is effectively a checkpoint within its task.
Here is an example:
```
import asyncio
import flyte
env = flyte.TaskEnvironment("env")
@flyte.trace
async def call_llm(prompt: str) -> str:
    await asyncio.sleep(0.1)
    return f"LLM response for: {prompt}"
@flyte.trace
async def process_data(data: str) -> dict:
    await asyncio.sleep(0.2)
    return {"processed": data, "status": "completed"}
@env.task
async def research_workflow(topic: str) -> dict:
    llm_result = await call_llm(f"Generate research plan for: {topic}")
    processed_data = await process_data(llm_result)
    return {"topic": topic, "result": processed_data}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/task_vs_trace.py)
## What Gets Traced
Traces capture detailed execution information:
- **Execution time**: How long each function call takes.
- **Inputs and outputs**: Function parameters and return values.
- **Checkpoints**: State that enables workflow resumption.
### Errors are not recorded
Only successful trace executions are recorded in the checkpoint system. When a traced function fails, the exception propagates up to your task code where you can handle it with standard error handling patterns.
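The checkpoint behavior described above (record on success, let failures propagate, replay recorded results on resumption) can be sketched as a small decorator. This is a hypothetical, in-memory model for illustration only: `checkpointed`, `CHECKPOINTS`, `CALLS`, and `FAIL_SUMMARY` are not Flyte APIs, and the real `@flyte.trace` persists checkpoints with the run rather than in a module-level dict. It only handles plain `async def` functions with JSON-serializable arguments.

```python
import asyncio
import functools
import json
from typing import Any, Awaitable, Callable

# Hypothetical checkpoint store: (function name, serialized args) -> result.
CHECKPOINTS: dict[tuple[str, str], Any] = {}
CALLS: list[str] = []  # records which functions actually executed


def checkpointed(fn: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    @functools.wraps(fn)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        key = (fn.__name__, json.dumps([args, kwargs], sort_keys=True))
        if key in CHECKPOINTS:
            return CHECKPOINTS[key]  # replay: skip re-execution
        CALLS.append(fn.__name__)
        result = await fn(*args, **kwargs)  # an exception propagates; nothing is recorded
        CHECKPOINTS[key] = result
        return result
    return wrapper


FAIL_SUMMARY = True  # simulates a transient failure on the first attempt


@checkpointed
async def search(query: str) -> list[str]:
    return [f"article about {query}"]


@checkpointed
async def summarize(text: str) -> str:
    if FAIL_SUMMARY:
        raise ConnectionError("LLM unavailable")
    return f"summary of {text!r}"


async def pipeline(topic: str) -> str:
    results = await search(topic)
    return await summarize(results[0])


try:
    asyncio.run(pipeline("flyte"))  # first attempt fails inside summarize
except ConnectionError as e:
    print("first attempt failed:", e)

FAIL_SUMMARY = False
print(asyncio.run(pipeline("flyte")))  # resumption: search is replayed, summarize re-runs
print(CALLS)  # ['search', 'summarize', 'summarize']
```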
### Supported Function Types
The trace decorator works with:
- **Asynchronous functions**: Functions defined with `async def`.
- **Generator functions**: Functions that `yield` values.
- **Async generators**: `async def` functions that `yield` values.
> [!NOTE]
> Currently tracing only works for asynchronous functions. Tracing of synchronous functions is coming soon.
```
@flyte.trace
async def async_api_call(topic: str) -> dict:
    # Asynchronous API call
    await asyncio.sleep(0.1)
    return {"data": ["item1", "item2", "item3"], "status": "success"}
@flyte.trace
async def stream_data(items: list[str]):
    # Async generator function for streaming
    for item in items:
        await asyncio.sleep(0.02)
        yield f"Processing: {item}"
@flyte.trace
async def async_stream_llm(prompt: str):
    # Async generator for streaming LLM responses
    chunks = ["Research shows", " that machine learning", " continues to evolve."]
    for chunk in chunks:
        await asyncio.sleep(0.05)
        yield chunk
@env.task
async def research_workflow(topic: str) -> dict:
    llm_result = await async_api_call(topic)
    # Collect async generator results
    processed_data = []
    async for item in stream_data(llm_result["data"]):
        processed_data.append(item)
    llm_stream = []
    async for chunk in async_stream_llm(f"Summarize research on {topic}"):
        llm_stream.append(chunk)
    return {
        "topic": topic,
        "processed_data": processed_data,
        "llm_summary": "".join(llm_stream)
    }
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/function_types.py)
## Task Orchestration Pattern
The typical Flyte workflow follows this pattern:
```
@flyte.trace
async def search_web(query: str) -> list[dict]:
    # Search the web and return results
    await asyncio.sleep(0.1)
    return [{"title": f"Article about {query}", "content": f"Content on {query}"}]
@flyte.trace
async def summarize_content(content: str) -> str:
    # Summarize content using LLM
    await asyncio.sleep(0.1)
    return f"Summary of {len(content.split())} words"
@flyte.trace
async def extract_insights(summaries: list[str]) -> dict:
    # Extract insights from summaries
    await asyncio.sleep(0.1)
    return {"insights": ["key theme 1", "key theme 2"], "count": len(summaries)}
@env.task
async def research_pipeline(topic: str) -> dict:
    # Each helper function creates a checkpoint
    search_results = await search_web(f"research on {topic}")
    summaries = []
    for result in search_results:
        summary = await summarize_content(result["content"])
        summaries.append(summary)
    final_insights = await extract_insights(summaries)
    return {
        "topic": topic,
        "insights": final_insights,
        "sources_count": len(search_results)
    }
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/pattern.py)
**Benefits of this pattern:**
- If `search_web` succeeds but `summarize_content` fails, resumption skips the search step
- Each operation is independently observable and debuggable
- Clear separation between workflow coordination (task) and execution (traced functions)
## Relationship to Caching and Checkpointing
Understanding how traces work with Flyte's other execution features:
| Feature | Scope | Purpose | Default Behavior |
|---------|-------|---------|------------------|
| **Task Caching** | Entire task execution (`@env.task`) | Skip re-running tasks with same inputs | Enabled (`cache="auto"`) |
| **Traces** | Individual helper functions | Observability and fine-grained resumption | Manual (requires `@flyte.trace`) |
| **Checkpointing** | Workflow state | Resume workflows from failure points | Automatic when traces are used |
### How They Work Together
```
@flyte.trace
async def traced_data_cleaning(dataset_id: str) -> List[str]:
    # Creates checkpoint after successful execution.
    await asyncio.sleep(0.2)
    return [f"cleaned_record_{i}_{dataset_id}" for i in range(100)]
@flyte.trace
async def traced_feature_extraction(data: List[str]) -> dict:
    # Creates checkpoint after successful execution.
    await asyncio.sleep(0.3)
    return {
        "features": [f"feature_{i}" for i in range(10)],
        "feature_count": len(data),
        "processed_samples": len(data)
    }
@flyte.trace
async def traced_model_training(features: dict) -> dict:
    # Creates checkpoint after successful execution.
    await asyncio.sleep(0.4)
    sample_count = features["processed_samples"]
    # Mock accuracy based on sample count
    accuracy = min(0.95, 0.7 + (sample_count / 1000))
    return {
        "accuracy": accuracy,
        "epochs": 50,
        "model_size": "125MB"
    }
@env.task(cache="auto")  # Task-level caching enabled
async def data_pipeline(dataset_id: str) -> dict:
    # 1. If this exact task with these inputs ran before,
    #    the entire task result is returned from cache
    # 2. If not cached, execution begins and each traced function
    #    creates checkpoints for resumption
    cleaned_data = await traced_data_cleaning(dataset_id)  # Checkpoint 1
    features = await traced_feature_extraction(cleaned_data)  # Checkpoint 2
    model_results = await traced_model_training(features)  # Checkpoint 3
    # 3. If workflow fails at step 3, resumption will:
    #    - Skip traced_data_cleaning (checkpointed)
    #    - Skip traced_feature_extraction (checkpointed)
    #    - Re-run only traced_model_training
    return {"dataset_id": dataset_id, "accuracy": model_results["accuracy"]}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/traces/caching_vs_checkpointing.py)
### Execution Flow
1. **Task Submission**: Task is submitted with input parameters
2. **Cache Check**: Flyte checks if identical task execution exists in cache
3. **Cache Hit**: If cached, return cached result immediately (no traces needed)
4. **Cache Miss**: Begin fresh execution
5. **Trace Checkpoints**: Each `@flyte.trace` function creates resumption points
6. **Failure Recovery**: If workflow fails, resume from last successful checkpoint
7. **Task Completion**: Final result is cached for future identical inputs
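To make the execution flow above concrete, here is a toy, in-process model written in plain Python with no Flyte dependency. The `traced` decorator, the `TASK_CACHE` and `CHECKPOINTS` dictionaries, and the simulated training failure are hypothetical stand-ins chosen for illustration. They are not how Flyte stores caches or checkpoints. The sketch only mirrors the logic of steps 2 through 7: a whole-task cache hit skips everything, and a retry after a failure replays completed traced calls instead of re-running them.

```python
import asyncio
from typing import Any, Dict, List, Tuple

# Toy, in-process stand-ins (hypothetical, not Flyte internals):
TASK_CACHE: Dict[Tuple[str, str], Any] = {}   # whole-task results, keyed by inputs
CHECKPOINTS: Dict[Tuple[str, str], Any] = {}  # per-trace results, keyed by inputs
EXECUTED: List[str] = []                      # traced functions that actually ran

def traced(fn):
    """Replay a stored result if this call already succeeded; otherwise run and record it."""
    async def wrapper(*args):
        key = (fn.__name__, repr(args))  # repr() so lists/dicts can act as keys
        if key in CHECKPOINTS:
            return CHECKPOINTS[key]
        EXECUTED.append(fn.__name__)
        result = await fn(*args)
        CHECKPOINTS[key] = result  # only reached if fn did not raise
        return result
    return wrapper

_fail_next_training = [True]  # make the first training attempt fail

@traced
async def clean(dataset_id: str) -> List[str]:
    return [f"cleaned_record_{i}_{dataset_id}" for i in range(3)]

@traced
async def extract(data: List[str]) -> dict:
    return {"processed_samples": len(data)}

@traced
async def train(features: dict) -> dict:
    if _fail_next_training[0]:
        _fail_next_training[0] = False
        raise RuntimeError("simulated failure during training")
    return {"accuracy": 0.9}

async def pipeline(dataset_id: str) -> dict:
    key = ("pipeline", repr((dataset_id,)))
    if key in TASK_CACHE:                # cache check / cache hit
        return TASK_CACHE[key]
    cleaned = await clean(dataset_id)    # checkpoint 1
    features = await extract(cleaned)    # checkpoint 2
    model = await train(features)        # checkpoint 3
    result = {"dataset_id": dataset_id, "accuracy": model["accuracy"]}
    TASK_CACHE[key] = result             # cache the final result
    return result

async def main():
    try:
        await pipeline("ds1")            # fails inside train()
    except RuntimeError:
        pass
    first = EXECUTED.copy()
    EXECUTED.clear()
    await pipeline("ds1")                # resumes: only train() runs again
    resumed = EXECUTED.copy()
    EXECUTED.clear()
    await pipeline("ds1")                # whole-task cache hit
    cached = EXECUTED.copy()
    return first, resumed, cached

first, resumed, cached = asyncio.run(main())
print(first, resumed, cached)
```

Running it prints `['clean', 'extract', 'train'] ['train'] []`: the first attempt runs all three steps and fails in the last one, the retry re-executes only `train`, and the third call is served entirely from the task cache.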
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming/grouping-actions ===
# Grouping actions
Groups are an organizational feature in Flyte that allow you to logically cluster related task invocations (called "actions") for better visualization and management in the UI.
Groups help you organize task executions into manageable, hierarchical structures regardless of whether you're working with large fanouts or smaller, logically-related sets of operations.
## What are groups?
Groups provide a way to organize task invocations into logical units in the Flyte UI.
When you have multiple task executions, whether from large fanouts (see **Build tasks > Fanout**), sequential operations, or any combination of tasks, groups help organize them into manageable units.
### The problem groups solve
Without groups, complex workflows can become visually overwhelming in the Flyte UI:
- Multiple task executions appear as separate nodes, making it hard to see the high-level structure
- Related operations are scattered throughout the workflow graph
- Debugging and monitoring become difficult when dealing with many individual task executions
Groups solve this by:
- **Organizing actions**: Multiple task executions within a group are presented as a hierarchical "folder" structure
- **Improving UI visualization**: Instead of many individual nodes cluttering the view, you see logical groups that can be collapsed or expanded
- **Aggregating status information**: Groups show aggregated run status (success/failure) of their contained actions when you hover over them in the UI
- **Maintaining execution parallelism**: Tasks run concurrently or sequentially exactly as they would without the group; groups only organize them for display
### How groups work
Groups are declared using the **Flyte SDK > Packages > flyte > Methods > group()** context manager.
Any task invocations that occur within the `with flyte.group()` block are automatically associated with that group:
```python
with flyte.group("my-group-name"):
    # All task invocations here belong to "my-group-name"
    result1 = await task_a(data)
    result2 = await task_b(data)
    result3 = await task_c(data)
```
The key points about groups:
1. **Context-based**: Use the `with flyte.group("name"):` context manager.
2. **Organizational tool**: Task invocations within the context are grouped together in the UI.
3. **UI folders**: Groups appear as collapsible/expandable folders in the Flyte UI run tree.
4. **Status aggregation**: Hover over a group in the UI to see aggregated success/failure information.
5. **Execution unchanged**: Tasks execute exactly as they would outside the group (in parallel or sequentially); groups only affect organization and visualization.
**Important**: Groups do not aggregate outputs. Each task execution still produces its own individual outputs. Groups are purely for organization and UI presentation.
## Common grouping patterns
### Sequential operations
Group related sequential operations that logically belong together:
```python
@env.task
async def data_pipeline(raw_data: str) -> str:
    with flyte.group("data-validation"):
        validated_data = await process_data(raw_data, "validate_schema")
        validated_data = await process_data(validated_data, "check_quality")
        validated_data = await process_data(validated_data, "remove_duplicates")

    with flyte.group("feature-engineering"):
        features = await process_data(validated_data, "extract_features")
        features = await process_data(features, "normalize_features")
        features = await process_data(features, "select_features")

    with flyte.group("model-training"):
        model = await process_data(features, "train_model")
        model = await process_data(model, "validate_model")
        final_model = await process_data(model, "save_model")

    return final_model
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py)
### Parallel processing with groups
Groups work well with parallel execution patterns:
```python
@env.task
async def parallel_processing_example(n: int) -> str:
    tasks = []
    with flyte.group("parallel-processing"):
        # Collect all task invocations first
        for i in range(n):
            tasks.append(process_item(i, "transform"))
        # Execute all tasks in parallel
        results = await asyncio.gather(*tasks)

    # Convert to string for consistent return type
    return f"parallel_results: {results}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py)
### Multi-phase workflows
Use groups to organize different phases of complex workflows:
```python
@env.task
async def multi_phase_workflow(data_size: int) -> str:
    # First phase: data preprocessing
    preprocessed = []
    with flyte.group("preprocessing"):
        for i in range(data_size):
            preprocessed.append(process_item(i, "preprocess"))
        phase1_results = await asyncio.gather(*preprocessed)

    # Second phase: main processing
    processed = []
    with flyte.group("main-processing"):
        for result in phase1_results:
            processed.append(process_item(result, "transform"))
        phase2_results = await asyncio.gather(*processed)

    # Third phase: postprocessing
    postprocessed = []
    with flyte.group("postprocessing"):
        for result in phase2_results:
            postprocessed.append(process_item(result, "postprocess"))
        final_results = await asyncio.gather(*postprocessed)

    # Convert to string for consistent return type
    return f"multi_phase_results: {final_results}"
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py)
### Nested groups
Groups can be nested to create hierarchical organization:
```python
@env.task
async def hierarchical_example(raw_data: str) -> str:
    with flyte.group("machine-learning-pipeline"):
        with flyte.group("data-preparation"):
            cleaned_data = await process_data(raw_data, "clean_data")
            split_data = await process_data(cleaned_data, "split_dataset")

        with flyte.group("model-experiments"):
            with flyte.group("hyperparameter-tuning"):
                best_params = await process_data(split_data, "tune_hyperparameters")

            with flyte.group("model-training"):
                model = await process_data(best_params, "train_final_model")

    return model
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py)
### Conditional grouping
Groups can be used with conditional logic:
```python
@env.task
async def conditional_processing(use_advanced_features: bool, input_data: str) -> str:
    base_result = await process_data(input_data, "basic_processing")

    if use_advanced_features:
        with flyte.group("advanced-features"):
            enhanced_result = await process_data(base_result, "advanced_processing")
            optimized_result = await process_data(enhanced_result, "optimize_result")
            return optimized_result
    else:
        with flyte.group("basic-features"):
            simple_result = await process_data(base_result, "simple_processing")
            return simple_result
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/grouping-actions/grouping.py)
## Key insights
Groups are primarily an organizational and UI visualization tool: they don't change how your tasks execute or aggregate their outputs, but they help organize related task invocations (actions) into collapsible folder-like structures for better workflow management and display. The aggregated status information (success/failure rates) is visible when hovering over group folders in the UI.
Groups make your Flyte workflows more maintainable and easier to understand, especially when working with complex workflows that involve multiple logical phases or large numbers of task executions. They serve as organizational "folders" in the UI's call stack tree, allowing you to collapse sections to reduce visual distraction while still seeing aggregated status information on hover.
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-programming/fanout ===
# Fanout
Flyte is designed to scale effortlessly, allowing you to run workflows with large fan-outs.
When you need to execute many tasks in parallel, such as processing a large dataset or running hyperparameter sweeps, Flyte provides powerful patterns to implement these operations efficiently.
> [!NOTE]
> By default, fanouts in Union are limited to a maximum size.
> Adjustments to this maximum can be made by consulting with the Union team.
> Full documentation of this aspect of fanout is coming soon.
## Understanding fanout
A "fanout" pattern occurs when you spawn multiple tasks concurrently.
Each task runs in its own container and contributes an output that you later collect.
The most common way to implement this is using the [`asyncio.gather`](https://docs.python.org/3/library/asyncio-task.html#asyncio.gather) function.
In Flyte terminology, each individual task execution is called an "action": a specific invocation of a task with particular inputs. When you call a task multiple times in a loop, you create multiple actions.
## Example
We start by importing our required packages, defining our Flyte environment, and creating a simple task that fetches user data from a mock API.
```python
import asyncio
from typing import List, Tuple

import flyte

env = flyte.TaskEnvironment("fanout_env")

@env.task
async def fetch_data(user_id: int) -> dict:
    """Simulate fetching user data from an API - good for parallel execution."""
    # Simulate network I/O delay
    await asyncio.sleep(0.1)
    return {
        "user_id": user_id,
        "name": f"User_{user_id}",
        "score": user_id * 10,
        "data": f"fetched_data_{user_id}"
    }
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/fanout/fanout.py)
### Parallel execution
Next we implement the most common fanout pattern, which is to collect task invocations and execute them in parallel using `asyncio.gather()`:
```python
@env.task
async def parallel_data_fetching(user_ids: List[int]) -> List[dict]:
    """Fetch data for multiple users in parallel - ideal for I/O bound operations."""
    tasks = []
    # Collect all fetch tasks - these can run in parallel since they're independent
    for user_id in user_ids:
        tasks.append(fetch_data(user_id))
    # Execute all fetch operations in parallel
    results = await asyncio.gather(*tasks)
    return results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/fanout/fanout.py)
### Running the example
To actually run our example, we create a main guard that initializes Flyte and runs our main driver task:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    user_ids = [1, 2, 3, 4, 5]
    r = flyte.run(parallel_data_fetching, user_ids)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/task-programming/fanout/fanout.py)
## How Flyte handles concurrency and parallelism
In the example we use a standard `asyncio.gather()` pattern.
When this pattern is used in a normal Python environment, the tasks would execute **concurrently** (cooperatively sharing a single thread through the event loop), but not in true **parallel** (multiple CPU cores simultaneously).
However, **Flyte transforms this concurrency model into true parallelism**. When you use `asyncio.gather()` in a Flyte task:
1. **Flyte acts as a distributed event loop**: Instead of scheduling coroutines on a single machine, Flyte schedules each task action to run in its own container across the cluster
2. **Concurrent becomes parallel**: What would be cooperative multitasking in regular Python becomes true parallel execution across multiple machines
3. **Native Python patterns**: You use familiar `asyncio` patterns, but Flyte automatically distributes the work
This means that when you write:
```python
results = await asyncio.gather(fetch_data(1), fetch_data(2), fetch_data(3))
```
Instead of three coroutines sharing one CPU, you get three separate containers running simultaneously, each with their own CPU, memory, and resources. Flyte seamlessly bridges the gap between Python's concurrency model and distributed parallel computing, allowing for massive scalability while maintaining the familiar async/await programming model.
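As a point of comparison, the snippet below runs the same `asyncio.gather` pattern in plain Python, outside Flyte. The local `fetch_data` coroutine is a stand-in for the task of the same name. All three calls finish in roughly the time of one 0.1-second sleep, and they all run on a single thread: this is the cooperative concurrency described above. Inside a Flyte task, the same call shape instead becomes one action per call, each in its own container.

```python
import asyncio
import threading
import time

async def fetch_data(user_id: int) -> dict:
    """Plain coroutine (no Flyte) that simulates a 0.1 s network call."""
    await asyncio.sleep(0.1)  # yields to the event loop while "waiting"
    return {"user_id": user_id, "thread": threading.get_ident()}

async def main():
    start = time.perf_counter()
    results = await asyncio.gather(fetch_data(1), fetch_data(2), fetch_data(3))
    return results, time.perf_counter() - start

results, elapsed = asyncio.run(main())
threads = {r["thread"] for r in results}
print(f"elapsed: {elapsed:.2f}s, threads used: {len(threads)}")
```

Replacing `await asyncio.sleep(0.1)` with CPU-bound work that never awaits would make the local version take roughly three times as long, because the event loop can only switch between coroutines at `await` points.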
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-deployment ===
# Run and deploy tasks
You have seen how to configure and build the tasks that compose your project.
Now you need to decide how to execute them on your Flyte backend.
Flyte offers two distinct approaches for getting your tasks onto the backend:
**Use `flyte run` when you're iterating and experimenting:**
- Quickly test changes during development
- Try different parameters or code modifications
- Debug issues without creating permanent artifacts
- Prototype new ideas rapidly
**Use `flyte deploy` when your project is ready to be formalized:**
- Freeze a stable version of your tasks for repeated use
- Share tasks with team members or across environments
- Move from experimentation to a more structured workflow
- Create a permanent reference point (not necessarily production-ready)
This section explains both approaches and when to use each one.
## Ephemeral deployment and immediate execution
The `flyte run` CLI command and the `flyte.run()` SDK function are used to **ephemerally deploy** and **immediately execute** a task on the backend in a single step.
The task can be re-run and its execution and outputs can be observed in the **Runs list** UI, but it is not permanently added to the **Tasks list** on the backend.
Let's say you have the following file called `greeting.py`:
```python
# greeting.py
import flyte

env = flyte.TaskEnvironment(name="greeting_env")

@env.task
async def greet(message: str) -> str:
    return f"{message}!"
```
### With the `flyte run` CLI command
The general form of the command for running a task from a local file is:
```bash
flyte run <file_path> <task_name> [--<parameter> <value> ...]
```
So, to run the `greet` task defined in the `greeting.py` file, you would run:
```bash
flyte run greeting.py greet --message "Good morning!"
```
This command:
1. **Temporarily deploys** the task environment named `greeting_env` (held by the variable `env`) that contains the `greet` task.
2. **Executes** the `greet` function with argument `message` set to `"Good morning!"`. Note that `message` is the actual parameter name defined in the function signature.
3. **Returns** the execution results and displays them in the terminal.
### With the `flyte.run()` SDK function
You can also do the same thing programmatically using the `flyte.run()` function:
```python
# greeting.py
import flyte

env = flyte.TaskEnvironment(name="greeting_env")

@env.task
async def greet(message: str) -> str:
    return f"{message}!"

if __name__ == "__main__":
    flyte.init_from_config()
    result = flyte.run(greet, message="Good morning!")
    print(f"Result: {result}")
```
Here we add a `__main__` block to the `greeting.py` file that initializes the Flyte SDK from the configuration file and then calls `flyte.run()` with the `greet` task and its argument.
Now you can run the `greet` task on the backend just by executing the `greeting.py` file locally as a script:
```bash
python greeting.py
```
For more details on how `flyte run` and `flyte.run()` work under the hood, see **Run and deploy tasks > How task run works**.
## Persistent deployment
The `flyte deploy` CLI command and the `flyte.deploy()` SDK function are used to **persistently deploy** a task environment (and all its contained tasks) to the backend.
The tasks within the deployed environment will appear in the **Tasks list** UI on the backend and can then be executed multiple times without needing to redeploy them.
### With the `flyte deploy` CLI command
The general form of the command for deploying a task environment from a local file is:
```bash
flyte deploy <file_path> <environment_variable>
```
So, using the same `greeting.py` file as before, you can deploy the `greeting_env` task environment like this:
```bash
flyte deploy greeting.py env
```
This command deploys the task environment *assigned to the variable `env`* in the `greeting.py` file, which is the `TaskEnvironment` named `greeting_env`.
Notice that you must specify the *variable* to which the `TaskEnvironment` is assigned (`env` in this case), not the name of the environment itself (`greeting_env`).
Deploying a task environment deploys all tasks defined within it. Here, that means all functions decorated with `@env.task`.
In this case there is just one: `greet()`.
### With the `flyte.deploy()` SDK function
You can also do the same thing programmatically using the `flyte.deploy()` function:
```python
# greeting.py
import flyte

env = flyte.TaskEnvironment(name="greeting_env")

@env.task
async def greet(message: str) -> str:
    return f"{message}!"

if __name__ == "__main__":
    flyte.init_from_config()
    deployments = flyte.deploy(env)
    print(deployments[0].summary_repr())
```
Now you can deploy the `greeting_env` task environment (and therefore the `greet()` task) just by executing the `greeting.py` file locally as a script.
```bash
python greeting.py
```
For more details on how `flyte deploy` and `flyte.deploy()` work under the hood, see **Run and deploy tasks > How task deployment works**.
## Running already deployed tasks
If you have already deployed your task environment, you can run its tasks without redeploying by using the `flyte run` CLI command or the `flyte.run()` SDK function in a slightly different way. Alternatively, you can always initiate execution of a deployed task from the UI.
### With the `flyte run` CLI command
To run a permanently deployed task using the `flyte run` CLI command, use the special `deployed-task` keyword followed by the task reference in the format `{environment_name}.{task_name}`. For example, to run the previously deployed `greet` task from the `greeting_env` environment, you would run:
```bash
flyte run deployed-task greeting_env.greet --message "World"
```
Notice that now that the task environment is deployed, you use its name (`greeting_env`), not the variable name to which it was assigned in source code (`env`).
The task environment name plus the task name (`greet`) are combined with a dot (`.`) to form the full task reference: `greeting_env.greet`.
The special `deployed-task` keyword tells the CLI that you are referring to a task that has already been deployed. In effect, it replaces the file path argument used for ephemeral runs.
When executed, this command will run the already-deployed `greet` task with argument `message` set to `"World"`. You will see the result printed in the terminal. You can also, of course, observe the execution in the **Runs list** UI.
### With the `flyte.run()` SDK function
You can also run already-deployed tasks programmatically using the `flyte.run()` function.
For example, to run the previously deployed `greet` task from the `greeting_env` environment, you would do:
```python
# greeting.py
import flyte

env = flyte.TaskEnvironment(name="greeting_env")

@env.task
async def greet(message: str) -> str:
    return f"{message}!"

if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env)
    task = flyte.remote.Task.get("greeting_env.greet", auto_version="latest")
    result = flyte.run(task, message="Good morning!")
    print(f"Result: {result}")
```
When you execute this script locally, it will:
- Deploy the `greeting_env` task environment as before.
- Retrieve the already-deployed `greet` task using `flyte.remote.Task.get()`, specifying its full task reference as a string: `"greeting_env.greet"`.
- Call `flyte.run()` with the retrieved task and its argument.
For more details on how running already-deployed tasks works, see **Run and deploy tasks > How task run works > Running deployed tasks**.
## Subpages
- **Run and deploy tasks > How task run works**
- **Run and deploy tasks > Run command options**
- **Run and deploy tasks > How task deployment works**
- **Run and deploy tasks > Deploy command options**
- **Run and deploy tasks > Code packaging for remote execution**
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-deployment/how-task-run-works ===
# How task run works
The `flyte run` command and `flyte.run()` SDK function support three primary execution modes:
1. **Ephemeral deployment + run**: Automatically prepare task environments ephemerally and execute tasks (development shortcut)
2. **Run deployed task**: Execute permanently deployed tasks without redeployment
3. **Local execution**: Run tasks on your local machine for development and testing
Additionally, you can run deployed tasks through the Flyte/Union UI for interactive execution and monitoring.
## Ephemeral deployment + run: The development shortcut
The most common development pattern combines ephemeral task preparation and execution in a single command, automatically handling the temporary deployment process when needed.
### CLI: Ephemeral deployment and execution
```bash
# Basic deploy + run
flyte run my_example.py my_task --name "World"
# With explicit project and domain
flyte run --project my-project --domain development my_example.py my_task --name "World"
# With deployment options
flyte run --version v1.0.0 --copy-style all my_example.py my_task --name "World"
```
**How it works:**
1. **Environment discovery**: Flyte loads the specified Python file and identifies task environments
2. **Ephemeral preparation**: Temporarily prepares the task environment for execution (similar to deployment but not persistent)
3. **Task execution**: Immediately runs the specified task with provided arguments in the ephemeral environment
4. **Result return**: Returns execution results and monitoring URL
5. **Cleanup**: The ephemeral environment is not stored permanently in the backend
### SDK: Programmatic ephemeral deployment + run
```python
import flyte

env = flyte.TaskEnvironment(name="my_env")

@env.task
async def my_task(name: str) -> str:
    return f"Hello, {name}!"

if __name__ == "__main__":
    flyte.init_from_config()
    # Deploy and run in one step
    result = flyte.run(my_task, name="World")
    print(f"Result: {result}")
    print(f"Execution URL: {result.url}")
```
**Benefits of ephemeral deployment + run:**
- **Development efficiency**: No separate permanent deployment step required
- **Always current**: Uses your latest code changes without polluting the backend
- **Clean development**: Ephemeral environments don't clutter your task registry
- **Integrated workflow**: Single command for complete development cycle
## Running deployed tasks
For production workflows or when you want to use stable deployed versions, you can run tasks that have been **permanently deployed** with `flyte deploy` without triggering any deployment process.
### CLI: Running deployed tasks
```bash
# Run a previously deployed task
flyte run deployed-task my_env.my_task --name "World"
# With specific project/domain
flyte run --project prod --domain production deployed-task my_env.my_task --batch_size 1000
```
**Task reference format:** `{environment_name}.{task_name}`
- `environment_name`: The `name` property of your `TaskEnvironment`
- `task_name`: The function name of your task
> [!NOTE]
> Recall that when you deploy a task environment with `flyte deploy`, you specify the `TaskEnvironment` using the variable to which it is assigned.
> In contrast, once it is deployed, you refer to the environment by its `name` property.
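The reference format is simple enough to build and parse by hand, which can help when scripting against deployed tasks. The helpers below (`TaskReference`, `parse_task_reference`, and `reference_for`) are hypothetical and are not part of the Flyte SDK. They assume, as described above, that a reference is the environment's `name` plus the task function's name joined with a dot. Building the reference from the function object avoids typos in the task name, but the environment part must still be the `TaskEnvironment`'s `name` (`greeting_env`), not the variable it is assigned to (`env`).

```python
from typing import Callable, NamedTuple

class TaskReference(NamedTuple):
    environment_name: str  # the TaskEnvironment's `name`, e.g. "greeting_env"
    task_name: str         # the task function's name, e.g. "greet"

    def __str__(self) -> str:
        return f"{self.environment_name}.{self.task_name}"

def parse_task_reference(ref: str) -> TaskReference:
    """Split '<environment_name>.<task_name>' on its last dot (function names have none)."""
    env_name, sep, task_name = ref.rpartition(".")
    if not (sep and env_name and task_name):
        raise ValueError(f"expected '<environment_name>.<task_name>', got {ref!r}")
    return TaskReference(env_name, task_name)

def reference_for(environment_name: str, task_fn: Callable) -> str:
    """Build a reference from the environment's `name` and the task function itself."""
    return str(TaskReference(environment_name, task_fn.__name__))

def greet(message: str) -> str:  # plain function standing in for the `greet` task
    return f"{message}!"

ref = parse_task_reference("greeting_env.greet")
print(ref.environment_name, ref.task_name)   # greeting_env greet
print(reference_for("greeting_env", greet))  # greeting_env.greet
```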
### SDK: Running deployed tasks
```python
import flyte
flyte.init_from_config()
# Method 1: Using remote task reference
deployed_task = flyte.remote.Task.get("my_env.my_task", version="v1.0.0")
result = flyte.run(deployed_task, name="World")
# Method 2: Get latest version
deployed_task = flyte.remote.Task.get("my_env.my_task", auto_version="latest")
result = flyte.run(deployed_task, name="World")
```
**Benefits of running deployed tasks:**
- **Performance**: No deployment overhead, faster execution startup
- **Stability**: Uses tested, stable versions of your code
- **Production safety**: Isolated from local development changes
- **Version control**: Explicit control over which code version runs
## Local execution
For development, debugging, and testing, you can run tasks locally on your machine without any backend interaction.
### CLI: Local execution
```bash
# Run locally with --local flag
flyte run --local my_example.py my_task --name "World"
# Local execution with development data
flyte run --local data_pipeline.py process_data --input_path "/local/data" --debug true
```
### SDK: Local execution
```python
import flyte

env = flyte.TaskEnvironment(name="my_env")

@env.task
async def my_task(name: str) -> str:
    return f"Hello, {name}!"

# Method 1: No client configured (defaults to local)
result = flyte.run(my_task, name="World")

# Method 2: Explicit local mode
flyte.init_from_config()  # Client configured
result = flyte.with_runcontext(mode="local").run(my_task, name="World")
```
**Benefits of local execution:**
- **Rapid development**: Instant feedback without network latency
- **Debugging**: Full access to local debugging tools
- **Offline development**: Works without backend connectivity
- **Resource efficiency**: Uses local compute resources
## Running tasks through the Union UI
If you are running your Flyte code on a Union backend, the UI provides an interactive way to run deployed tasks with form-based input and real-time monitoring.
### Accessing task execution in the Union UI
1. **Navigate to tasks**: Go to your project → domain → Tasks section
2. **Select task**: Choose the task environment and specific task
3. **Launch execution**: Click "Launch" to open the execution form
4. **Provide inputs**: Fill in task parameters through the web interface
5. **Monitor progress**: Watch real-time execution progress and logs
**UI execution benefits:**
- **User-friendly**: No command-line expertise required
- **Visual monitoring**: Real-time progress visualization
- **Input validation**: Built-in parameter validation and type checking
- **Execution history**: Easy access to previous runs and results
- **Sharing**: Shareable execution URLs for collaboration
Here is a short video demonstrating task execution through the Union UI:
[Watch on YouTube](https://www.youtube.com/watch?v=8jbau9yGoDg)
## Execution flow and architecture
### Fast registration architecture
Flyte v2 uses "fast registration" to enable rapid development cycles:
#### How it works
1. **Container images** contain the runtime environment and dependencies
2. **Code bundles** contain your Python source code (stored separately)
3. **At runtime**: Code bundles are downloaded and injected into running containers
#### Benefits
- **Rapid iteration**: Update code without rebuilding images
- **Resource efficiency**: Share images across multiple deployments
- **Version flexibility**: Run different code versions with same base image
- **Caching optimization**: Separate caching for images vs. code
#### When code gets injected
At task execution time, the fast registration process follows these steps:
1. **Container starts** with the base image containing runtime environment and dependencies
2. **Code bundle download**: The Flyte agent downloads your Python code bundle from storage
3. **Code extraction**: The code bundle is extracted and mounted into the running container
4. **Task execution**: Your task function executes with the injected code
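The following plain-Python sketch imitates those steps on a single machine: one interpreter plays the role of the container image, and each call extracts a different code bundle into a fresh directory, imports it, and runs a function from it. The helper names, the `.tar.gz` format, and the `greeting` module are assumptions for illustration only; Flyte's actual bundle format and download mechanism are not shown here.

```python
import importlib
import io
import sys
import tarfile
import tempfile
from typing import Dict

def build_code_bundle(files: Dict[str, str]) -> bytes:
    """Package source files (relative name -> source text) into a .tar.gz in memory."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tar:
        for name, text in files.items():
            data = text.encode("utf-8")
            info = tarfile.TarInfo(name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    return buf.getvalue()

def run_with_injected_code(bundle: bytes, module: str, func: str, *args):
    """Extract a bundle into a fresh directory, import it, and call `func`."""
    with tempfile.TemporaryDirectory() as workdir:
        with tarfile.open(fileobj=io.BytesIO(bundle), mode="r:gz") as tar:
            if hasattr(tarfile, "data_filter"):  # safer extraction where supported
                tar.extractall(workdir, filter="data")
            else:
                tar.extractall(workdir)
        sys.path.insert(0, workdir)
        importlib.invalidate_caches()
        try:
            loaded = importlib.import_module(module)
            return getattr(loaded, func)(*args)
        finally:
            sys.path.remove(workdir)
            sys.modules.pop(module, None)  # so the next bundle is imported fresh

# The same interpreter (standing in for one container image) runs two code versions.
v1 = build_code_bundle({"greeting.py": "def greet(message):\n    return message + '!'\n"})
v2 = build_code_bundle({"greeting.py": "def greet(message):\n    return message + '!!'\n"})
print(run_with_injected_code(v1, "greeting", "greet", "Good morning"))  # Good morning!
print(run_with_injected_code(v2, "greeting", "greet", "Good morning"))  # Good morning!!
```

The runtime stays the same while the injected code changes, which is the property that lets fast registration skip image rebuilds.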
### Ephemeral preparation logic
When using ephemeral deploy + run mode, Flyte determines whether temporary preparation is needed:
```mermaid
graph TD
A[flyte run command] --> B{Need preparation?}
B -->|Yes| C[Ephemeral preparation]
B -->|No| D[Use cached preparation]
C --> E[Execute task]
D --> E
E --> F[Cleanup ephemeral environment]
```
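The diagram does not say how Flyte decides whether preparation is needed. A common approach, assumed here purely for illustration, is to fingerprint the code bundle and reuse an earlier preparation when the fingerprint matches. The `PREPARED` registry, `bundle_digest`, and `prepare` below are hypothetical, and the cleanup step is omitted.

```python
import hashlib
from typing import Dict

# Hypothetical registry of earlier preparations, keyed by bundle fingerprint.
PREPARED: Dict[str, str] = {}

def bundle_digest(files: Dict[str, str]) -> str:
    """Fingerprint a code bundle from its file names and contents, in sorted order."""
    h = hashlib.sha256()
    for name in sorted(files):
        for part in (name, files[name]):
            data = part.encode("utf-8")
            h.update(len(data).to_bytes(8, "big"))  # length prefix keeps fields unambiguous
            h.update(data)
    return h.hexdigest()

def prepare(files: Dict[str, str]) -> str:
    """Return "cached" when an identical bundle was prepared before, else prepare it."""
    digest = bundle_digest(files)
    if digest in PREPARED:
        return "cached"  # "No" branch: reuse the earlier preparation
    PREPARED[digest] = f"bundle-{digest[:12]}"  # "Yes" branch: stand-in for upload
    return "prepared"

code_v1 = {"my_example.py": "print('v1')\n"}
code_v2 = {"my_example.py": "print('v2')\n"}
first = prepare(code_v1)    # prepared
second = prepare(code_v1)   # cached: nothing changed
changed = prepare(code_v2)  # prepared: the source changed
print(first, second, changed)
```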
### Execution modes comparison
| Mode | Deployment | Performance | Use Case | Code Version |
|------|------------|-------------|-----------|--------------|
| Ephemeral Deploy + Run | Ephemeral (temporary) | Medium | Development, testing | Latest local |
| Run Deployed | None (uses permanent deployment) | Fast | Production, stable runs | Deployed version |
| Local | None | Variable | Development, debugging | Local |
| UI | None | Fast | Interactive, collaboration | Deployed version |
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-deployment/run-command-options ===
# Run command options
The `flyte run` command provides the following options:
**`flyte run [OPTIONS] <file_path>|deployed-task <task_name>`**
| Option | Short | Type | Default | Description |
|-----------------------------|-------|--------|---------------------------|--------------------------------------------------------|
| `--project` | `-p` | text | *from config* | Project to run tasks in |
| `--domain` | `-d` | text | *from config* | Domain to run tasks in |
| `--local` | | flag | `false` | Run the task locally |
| `--copy-style`              |       | choice | `loaded_modules`          | Code bundling strategy: `loaded_modules`, `all`, or `none` |
| `--root-dir` | | path | *current dir* | Override source root directory |
| `--raw-data-path` | | text | | Override the output location for offloaded data types. |
| `--service-account` | | text | | Kubernetes service account. |
| `--name` | | text | | Name of the run. |
| `--follow` | `-f` | flag | `false` | Wait and watch logs for the parent action. |
| `--image` | | text | | Image to be used in the run (format: `name=uri`). |
| `--no-sync-local-sys-paths` | | flag | `false` | Disable synchronization of local sys.path entries. |
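Because `flyte run` options and task arguments share the same `--key value` syntax, their position matters: in every example on this page, run options come before the file path, and task arguments come after the task name. For instance, `--name "World"` after `my_task` is the task's `name` argument, not the `--name` run option from the table. The hypothetical helper below assembles a command line in that order without executing anything; it covers only a few options from the table, and the resulting list could be passed to `subprocess.run` if the `flyte` CLI is installed.

```python
import shlex
from typing import List, Optional

def build_run_argv(
    file_path: str,
    task_name: str,
    *,
    project: Optional[str] = None,
    domain: Optional[str] = None,
    local: bool = False,
    copy_style: Optional[str] = None,
    follow: bool = False,
    **task_args: str,
) -> List[str]:
    """Assemble `flyte run` argv: run options first, then file, task, task arguments."""
    argv = ["flyte", "run"]
    if project:
        argv += ["--project", project]
    if domain:
        argv += ["--domain", domain]
    if local:
        argv.append("--local")
    if copy_style:
        argv += ["--copy-style", copy_style]
    if follow:
        argv.append("--follow")
    argv += [file_path, task_name]
    for key, value in task_args.items():  # everything after the task name goes to the task
        argv += [f"--{key}", str(value)]
    return argv

argv = build_run_argv(
    "my_example.py", "my_task", project="my-project", domain="development", name="World"
)
print(shlex.join(argv))
# flyte run --project my-project --domain development my_example.py my_task --name World
```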
## `--project`, `--domain`
**`flyte run --domain <domain> --project <project> <file_path>|deployed-task <task_name>`**
You can specify `--project` and `--domain` which will override any defaults defined in your `config.yaml`:
```bash
# Use defaults from default config.yaml
flyte run my_example.py my_task
# Specify target project and domain
flyte run --project my-project --domain development my_example.py my_task
```
## `--local`
**`flyte run --local <file_path> <task_name>`**
The `--local` option runs tasks locally instead of submitting them to the remote Flyte backend:
```bash
# Run task locally (same as calling flyte.run() with no client configured)
flyte run --local my_example.py my_task --input "test_data"
# Compare with remote execution
flyte run my_example.py my_task --input "test_data" # Runs on Flyte backend
```
### When to use local execution
- **Development and testing**: Quick iteration without deployment overhead
- **Debugging**: Full access to local debugging tools and environment
- **Resource constraints**: When remote resources are unavailable or expensive
- **Data locality**: When working with large local datasets
## `--copy-style`
**`flyte run --copy-style [loaded_modules|all|none] `**
The `--copy-style` option controls code bundling for remote execution.
This applies to the ephemeral preparation step of the `flyte run` command and works similarly to `flyte deploy`:
```bash
# Smart bundling (default) - includes only imported project modules
flyte run --copy-style loaded_modules my_example.py my_task
# Include all project files
flyte run --copy-style all my_example.py my_task
# No code bundling (task must be pre-deployed)
flyte run --copy-style none deployed_task my_deployed_task
```
### Copy style options
- **`loaded_modules` (default)**: Bundles only imported Python modules from your project
- **`all`**: Includes all files in the project directory
- **`none`**: No bundling; the code must already be available remotely, for example in a previously deployed task or baked into the container image
## `--root-dir`
**`flyte run --root-dir <path>`**
Override the source directory for code bundling and import resolution:
```bash
# Run from monorepo root with specific root directory
flyte run --root-dir ./services/ml ./services/ml/my_example.py my_task
# Handle cross-directory imports
flyte run --root-dir .. my_example.py my_workflow # When my_example.py imports sibling directories
```
This applies to the ephemeral preparation step of the `flyte run` command.
It works identically to the `flyte deploy` command's `--root-dir` option.
## `--raw-data-path`
**`flyte run --raw-data-path <path>`**
Override the default output location for offloaded data types (large objects, DataFrames, etc.):
```bash
# Use custom S3 location for large outputs
flyte run --raw-data-path s3://my-bucket/custom-path/ my_example.py process_large_data
# Use local directory for development
flyte run --local --raw-data-path ./output/ my_example.py my_task
```
### Use cases
- **Custom storage locations**: Direct outputs to specific S3 buckets or paths
- **Cost optimization**: Use cheaper storage tiers for temporary data
- **Access control**: Ensure outputs go to locations with appropriate permissions
- **Local development**: Store large outputs locally when testing
## `--service-account`
**`flyte run --service-account <service_account>`**
Specify a Kubernetes service account for task execution:
```bash
# Run with specific service account for cloud resource access
flyte run --service-account ml-service-account my_example.py train_model
# Use service account with specific permissions
flyte run --service-account data-reader-sa my_example.py load_data
```
### Use cases
- **Cloud resource access**: Service accounts with permissions for S3, GCS, etc.
- **Security isolation**: Different service accounts for different workload types
- **Compliance requirements**: Enforcing specific identity and access policies
## `--name`
**`flyte run --name <run_name>`**
Provide a custom name for the run:
```bash
# Named execution for easy identification
flyte run --name "daily-training-run-2024-12-02" my_example.py train_model
# Include experiment parameters in name
flyte run --name "experiment-lr-0.01-batch-32" my_example.py hyperparameter_sweep
```
### Benefits of custom names
- **Easy identification**: Find specific runs in the Flyte console
- **Experiment tracking**: Include key parameters or dates in names
- **Automation**: Programmatically generate meaningful names for scheduled runs
## `--follow`
**`flyte run --follow `**
Stream the logs of the parent action in real time and wait for the run to finish:
```bash
# Stream logs to console and wait for completion
flyte run --follow my_example.py long_running_task
# Combine with other options
flyte run --follow --name "training-session" my_example.py train_model
```
### Behavior
- **Log streaming**: Real-time output from task execution
- **Blocking execution**: Command waits until task completes
- **Exit codes**: Returns appropriate exit code based on task success/failure
## `--image`
**`flyte run --image <image_mapping>`**
Override container images during ephemeral preparation, same as the equivalent `flyte deploy` option:
```bash
# Override specific named image
flyte run --image gpu=ghcr.io/org/gpu:v2.1 my_example.py gpu_task
# Override default image
flyte run --image ghcr.io/org/custom:latest my_example.py my_task
# Multiple image overrides
flyte run \
--image base=ghcr.io/org/base:v1.0 \
--image gpu=ghcr.io/org/gpu:v2.0 \
my_example.py multi_env_workflow
```
### Image mapping formats
- **Named mapping**: `name=uri` overrides images created with `Image.from_ref_name("name")`
- **Default mapping**: `uri` overrides the default "auto" image
- **Multiple mappings**: Use multiple `--image` flags for different image references
## `--no-sync-local-sys-paths`
**`flyte run --no-sync-local-sys-paths `**
Disable synchronization of local `sys.path` entries to the remote execution environment during ephemeral preparation.
Identical to the `flyte deploy` command's `--no-sync-local-sys-paths` option:
```bash
# Disable path synchronization for clean container environment
flyte run --no-sync-local-sys-paths my_example.py my_task
```
This advanced option is useful for:
- **Container isolation**: Prevent local development paths from affecting remote execution
- **Custom environments**: When containers have pre-configured Python paths
- **Security**: Avoiding exposure of local directory structures
## Task argument passing
Task arguments are passed as flags on the CLI and as keyword arguments in the SDK:
```bash
# CLI: task arguments as flags
flyte run my_file.py my_task --name "World" --count 5 --debug true
# SDK equivalent (Python, not shell):
# result = flyte.run(my_task, name="World", count=5, debug=True)
```
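
The CLI flags line up with the task's parameter names. The helper below is a hypothetical sketch, not Flyte's actual argument parser, of that mapping. It assumes, as the `--input-data` flag used for an `input_data` parameter in the code packaging examples suggests, that underscores become hyphens, and that booleans are written as `true`/`false`.

```python
def to_cli_flags(kwargs: dict) -> list[str]:
    """Hypothetical helper: render task keyword arguments as `flyte run` flags."""
    flags: list[str] = []
    for name, value in kwargs.items():
        # Assumption: parameter names map to flags with underscores as hyphens
        flag = "--" + name.replace("_", "-")
        # Assumption: booleans are written as lowercase true/false
        text = str(value).lower() if isinstance(value, bool) else str(value)
        flags.extend([flag, text])
    return flags


# Mirrors: flyte run my_file.py my_task --name "World" --count 5 --debug true
print(to_cli_flags({"name": "World", "count": 5, "debug": True}))
```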
## SDK options
The core `flyte run` functionality is also available programmatically through the `flyte.run()` function, with extensive configuration options available via the `flyte.with_runcontext()` function:
```python
import flyte

# Run context configuration (my_task is a task defined elsewhere)
result = flyte.with_runcontext(
    mode="remote",                # "remote" or "local"
    copy_style="loaded_modules",  # Code bundling strategy
    version="v1.0.0",             # Ephemeral preparation version
    dry_run=False,                # Preview mode
).run(my_task, name="World")
```
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-deployment/how-task-deployment-works ===
# How task deployment works
In this section, we will take a deep dive into how the `flyte deploy` command and the `flyte.deploy()` SDK function work under the hood to deploy tasks to your Flyte backend.
When you perform a deployment, here's what happens:
## 1. Module loading and task environment discovery
In the first step, Flyte determines which files to load in order to search for task environments, based on the command line options provided:
### Single file (default)
```bash
flyte deploy my_example.py env
```
- The file `my_example.py` is executed.
- All `TaskEnvironment` objects declared in the file are instantiated, but only the one assigned to the variable `env` is selected for deployment.
### `--all` option
```bash
flyte deploy --all my_example.py
```
- The file `my_example.py` is executed.
- All declared `TaskEnvironment` objects in the file are instantiated and selected for deployment.
- No specific variable name is required.
### `--recursive` option
```bash
flyte deploy --recursive ./directory
```
- The directory is traversed recursively: every Python file in it is executed and every `TaskEnvironment` object is instantiated.
- All `TaskEnvironment` objects across all files are selected for deployment.
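
The selection rules for the three discovery modes can be sketched in plain Python. Here `TaskEnvironment` is a stand-in dataclass and `select_environments` is a hypothetical helper, not part of the Flyte SDK. The sketch only shows which environments get picked from a file's globals once the file has been executed; with `--recursive`, the same selection would apply to each file and the results would be combined.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TaskEnvironment:
    """Stand-in for flyte.TaskEnvironment; only the name matters here."""
    name: str


def select_environments(namespace: dict, variable: Optional[str] = None,
                        select_all: bool = False) -> list:
    """Pick environments out of an executed module's globals (hypothetical)."""
    envs = {k: v for k, v in namespace.items() if isinstance(v, TaskEnvironment)}
    if select_all:
        # --all (and --recursive, per file): every instantiated environment
        return list(envs.values())
    if variable is None:
        raise ValueError("a variable name is required unless select_all=True")
    # Default: only the environment bound to the named variable
    return [envs[variable]]


# Globals left behind after "executing" a file that defines two environments
module_globals = {"env": TaskEnvironment("default"),
                  "gpu_env": TaskEnvironment("gpu"),
                  "helper_value": 1}
print([e.name for e in select_environments(module_globals, "env")])
print([e.name for e in select_environments(module_globals, select_all=True)])
```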
## 2. Task analysis and serialization
- For every task environment selected for deployment, all of its tasks are identified.
- Task metadata is extracted: parameter types, return types, and resource requirements.
- Each task is serialized into a Flyte `TaskTemplate`.
- Dependency graphs between environments are built (see below).
## 3. Task environment dependency resolution
In many cases, a task in one environment may invoke a task in another environment, establishing a dependency between the two environments.
For example, if `env_a` has a task that calls a task in `env_b`, then `env_a` depends on `env_b`.
This means that when deploying `env_a`, `env_b` must also be deployed to ensure that all tasks can be executed correctly.
To handle this, `TaskEnvironment`s can declare dependencies on other `TaskEnvironment`s using the `depends_on` parameter.
During deployment, the system performs the following steps to resolve these dependencies:
1. Starting with specified environment(s)
2. Recursively discovering all transitive dependencies
3. Including all dependencies in the deployment plan
4. Processing dependencies depth-first to ensure correct order
```python
import flyte

# Define environments with dependencies
prep_env = flyte.TaskEnvironment(name="preprocessing")
ml_env = flyte.TaskEnvironment(name="ml_training", depends_on=[prep_env])
viz_env = flyte.TaskEnvironment(name="visualization", depends_on=[ml_env])

flyte.init_from_config()
# Deploy only viz_env - automatically includes ml_env and prep_env
deployment = flyte.deploy(viz_env, version="v2.0.0")
# Or deploy multiple environments explicitly
deployment = flyte.deploy(prep_env, ml_env, viz_env, version="v2.0.0")
```
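
The resolution steps can be sketched as a depth-first traversal. `Env` and `deployment_plan` below are hypothetical stand-ins for illustration, not Flyte's resolver. The sketch places each dependency before the environments that depend on it and visits a shared dependency only once, which also keeps it from looping on a cycle.

```python
from dataclasses import dataclass, field


@dataclass
class Env:
    """Stand-in for a TaskEnvironment with a depends_on list."""
    name: str
    depends_on: list["Env"] = field(default_factory=list)


def deployment_plan(*roots: Env) -> list[str]:
    """Return environment names with dependencies ordered before dependents."""
    plan: list[str] = []
    seen: set[int] = set()

    def visit(env: Env) -> None:
        if id(env) in seen:
            return  # already planned (shared or repeated dependency)
        seen.add(id(env))
        for dep in env.depends_on:  # recurse into transitive dependencies first
            visit(dep)
        plan.append(env.name)

    for root in roots:
        visit(root)
    return plan


prep_env = Env("preprocessing")
ml_env = Env("ml_training", depends_on=[prep_env])
viz_env = Env("visualization", depends_on=[ml_env])
print(deployment_plan(viz_env))  # ['preprocessing', 'ml_training', 'visualization']
```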
For detailed information about working with multiple environments, see **Configure tasks > Multiple environments**.
## 4. Code bundle creation and upload
Once the task environments and their dependencies are resolved, Flyte proceeds to package your code into a bundle based on the `copy_style` option:
### `--copy-style loaded_modules` (default)
This is the smart bundling approach that analyzes which Python modules were actually imported during the task environment discovery phase.
It examines the runtime module registry (`sys.modules`) and includes only those modules that meet specific criteria:
they must have source files located within your project directory (not in system locations like `site-packages`), and they must not be part of the Flyte SDK itself.
This selective approach results in smaller, faster-to-upload bundles that contain exactly the code needed to run your tasks, making it ideal for most development and production scenarios.
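
The selection criteria can be approximated in plain Python by scanning `sys.modules`. `loaded_project_modules` below is a sketch under stated assumptions, not Flyte's implementation. It treats the SDK as any module named `flyte` or `flyte.*`, and it identifies installed packages via the `purelib`/`platlib` paths reported by `sysconfig`. The demo module name `demo_project_helpers` is hypothetical.

```python
import importlib
import sys
import sysconfig
import tempfile
from pathlib import Path


def loaded_project_modules(root_dir: Path) -> list[Path]:
    """Approximate the loaded_modules selection rules (a sketch, not Flyte's code)."""
    root = root_dir.resolve()
    # Install locations for third-party packages, e.g. site-packages
    site_dirs = {Path(sysconfig.get_paths()[key]).resolve() for key in ("purelib", "platlib")}
    selected = set()
    for name, module in list(sys.modules.items()):
        source = getattr(module, "__file__", None)
        # Built-in modules have no source file; the Flyte SDK is always excluded
        if not source or name == "flyte" or name.startswith("flyte."):
            continue
        path = Path(source).resolve()
        if not path.is_relative_to(root):
            continue  # source lives outside the project directory
        if any(path.is_relative_to(site_dir) for site_dir in site_dirs):
            continue  # e.g. a virtualenv's site-packages inside the project
        selected.add(path)
    return sorted(selected)


# Demo: a throwaway project with one module imported at module level
project = Path(tempfile.mkdtemp())
(project / "demo_project_helpers.py").write_text("def transform(x):\n    return x * 2\n")
sys.path.insert(0, str(project))
importlib.import_module("demo_project_helpers")
print(loaded_project_modules(project))  # only demo_project_helpers.py
```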
### `--copy-style all`
This comprehensive bundling strategy takes a directory-walking approach, recursively traversing your entire project directory and including every file it encounters.
Unlike the smart bundling that only includes imported Python modules, this method captures all project files regardless of whether they were imported during discovery.
This is particularly useful for projects that use dynamic imports, load configuration files or data assets at runtime, or have dependencies that aren't captured through normal Python import mechanisms.
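
A directory walk of this kind can be sketched with `os.walk`. `all_project_files` is hypothetical and applies no ignore rules, so it says nothing about whether Flyte skips anything such as version-control metadata. It only illustrates that non-Python files, like configuration and data assets, are picked up along with the code.

```python
import os
import tempfile
from pathlib import Path


def all_project_files(root_dir: Path) -> list[str]:
    """List every file under root_dir, relative to it (sketch of a directory walk)."""
    files = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        dirnames.sort()  # walk subdirectories in a stable order
        for filename in sorted(filenames):
            files.append((Path(dirpath) / filename).relative_to(root_dir).as_posix())
    return files


# Demo project mixing code, configuration, and a data asset
project = Path(tempfile.mkdtemp())
(project / "app.py").write_text("print('hello')\n")
(project / "config.yaml").write_text("threshold: 0.5\n")
(project / "data").mkdir()
(project / "data" / "lookup.csv").write_text("id,label\n1,a\n")
print(all_project_files(project))  # ['app.py', 'config.yaml', 'data/lookup.csv']
```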
### `--copy-style none`
This option completely skips code bundle creation, meaning no source code is packaged or uploaded to cloud storage.
When using this approach, you must provide an explicit version parameter since there's no code bundle to generate a version from.
This strategy is designed for scenarios where your code is already baked into custom container images, eliminating the need for separate code injection during task execution.
It results in the fastest deployment times but requires more complex image management workflows.
### `--root-dir` option
By default, Flyte uses your current working directory as the root for code bundling.
You can override this with `--root-dir` to specify a different base directory, which is particularly useful for monorepos or when deploying from subdirectories. The root directory affects bundling as follows: with `loaded_modules`, imported modules are located relative to the root directory, and with `all`, the directory walk starts from the root directory. See **Run and deploy tasks > Deploy command options > `--root-dir`** for detailed usage examples.
After the code bundle is created (if applicable), it is uploaded to a cloud storage location (like S3 or GCS) accessible by your Flyte backend. It is now ready to be run.
## 5. Image building
If your `TaskEnvironment` specifies container images (see **Configure tasks > Container images**), Flyte builds and pushes them before deploying tasks.
The build process varies based on your configuration and backend type:
### Local image building
When `image.builder` is set to `local` in your configuration (see **Getting started > Local setup**), images are built on your local machine using Docker. This approach:
- Requires Docker to be installed and running on your development machine
- Uses Docker BuildKit to build images from generated Dockerfiles or your custom Dockerfile
- Pushes built images to the container registry specified in your `Image` configuration
- Is the only option available for Flyte OSS instances
### Remote image building
When `image.builder` is set to `remote` in your configuration (see **Getting started > Local setup**), images are built on cloud infrastructure. This approach:
- Builds images using Union's ImageBuilder service (currently only available for Union backends, not OSS Flyte)
- Requires no local Docker installation or configuration
- Can push to Union's internal registry or external registries you specify
- Provides faster, more consistent builds by leveraging cloud resources
> [!NOTE]
> Remote building is currently exclusive to Union backends. OSS Flyte installations must use `local`.
## Understanding option relationships
It's important to understand how the various deployment options work together.
The **discovery options** (`--recursive` and `--all`) operate independently of the **bundling options** (`--copy-style`),
giving you flexibility in how you structure your deployments.
Environment discovery determines which files Flyte will examine to find `TaskEnvironment` objects,
while code bundling controls what gets packaged and uploaded for execution.
You can freely combine these approaches.
For example, you can discover environments recursively across your entire project while using smart bundling to include only the necessary code modules.
When multiple environments are discovered, they all share the same code bundle, which is efficient for related services or components that use common dependencies:
```bash
# All discovered environments share the same code bundle
flyte deploy --recursive --copy-style loaded_modules ./project
```
For a full overview of all deployment options, see **Flyte CLI > flyte > flyte deploy**.
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-deployment/deploy-command-options ===
# Deploy command options
The `flyte deploy` command provides extensive configuration options:
**`flyte deploy [OPTIONS] [TASK_ENV_VARIABLE]`**
| Option | Short | Type | Default | Description |
|-----------------------------|-------|--------|---------------------------|---------------------------------------------------|
| `--project` | `-p` | text | *from config* | Project to deploy to |
| `--domain` | `-d` | text | *from config* | Domain to deploy to |
| `--version` | | text | *auto-generated* | Explicit version tag for deployment |
| `--dry-run`/`--dryrun` | | flag | `false` | Preview deployment without executing |
| `--all` | | flag | `false` | Deploy all environments in specified path |
| `--recursive` | `-r` | flag | `false` | Deploy environments recursively in subdirectories |
| `--copy-style`              |       | choice | `loaded_modules`          | Code bundling strategy (`loaded_modules`, `all`, or `none`) |
| `--root-dir` | | path | *current dir* | Override source root directory |
| `--image` | | text | | Image URI mappings (format: `name=uri`) |
| `--ignore-load-errors` | `-i` | flag | `false` | Continue deployment despite module load failures |
| `--no-sync-local-sys-paths` | | flag | `false` | Disable local `sys.path` synchronization |
## `--project`, `--domain`
**`flyte deploy --domain <domain> --project <project>`**
You can specify `--project` and `--domain`, which override any defaults defined in your `config.yaml`:
```bash
# Use defaults from default config.yaml
flyte deploy my_example.py env
# Specify target project and domain
flyte deploy --project my-project --domain development my_example.py env
```
## `--version`
**`flyte deploy --version <version>`**
The `--version` option controls how deployed tasks are tagged and identified in the Flyte backend:
```bash
# Auto-generated version (default)
flyte deploy my_example.py env
# Explicit version
flyte deploy --version v1.0.0 my_example.py env
# Required when using copy-style none (no code bundle to generate hash from)
flyte deploy --copy-style none --version v1.0.0 my_example.py env
```
### When versions are used
- **Explicit versioning**: Provides human-readable task identification (e.g., `v1.0.0`, `prod-2024-12-01`)
- **Auto-generated versions**: When no version is specified, Flyte creates an MD5 hash from the code bundle, environment configuration, and image cache
- **Version requirement**: `copy-style none` mandates explicit versions since there's no code bundle to hash
- **Task referencing**: Versions enable precise task references in `flyte run deployed_task` and workflow invocations
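
To illustrate how a content-derived version behaves, here is a minimal sketch that hashes a bundle together with environment settings and image URIs using MD5. The exact inputs Flyte feeds into its hash, and how it serializes them, are not documented here, so `auto_version` and its JSON encoding are assumptions. The point is only that identical inputs yield identical versions and any change yields a new one.

```python
import hashlib
import json


def auto_version(bundle: bytes, env_config: dict, image_uris: dict) -> str:
    """Derive a version from deployment inputs (illustrative assumptions only)."""
    digest = hashlib.md5()
    digest.update(bundle)
    # sort_keys makes the encoding independent of dict insertion order
    digest.update(json.dumps(env_config, sort_keys=True).encode())
    digest.update(json.dumps(image_uris, sort_keys=True).encode())
    return digest.hexdigest()


images = {"default": "ghcr.io/org/img:abc"}
v1 = auto_version(b"bundle bytes", {"name": "my_env"}, images)
v2 = auto_version(b"bundle bytes", {"name": "my_env"}, images)
v3 = auto_version(b"changed bundle bytes", {"name": "my_env"}, images)
print(v1 == v2, v1 == v3)  # True False
```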
## `--dry-run`
**`flyte deploy --dry-run `**
The `--dry-run` option allows you to preview what would be deployed without actually performing the deployment:
```bash
# Preview what would be deployed
flyte deploy --dry-run my_example.py env
```
## `--all` and `--recursive`
**`flyte deploy --all `**
**`flyte deploy --recursive `**
Control which environments get discovered and deployed:
**Single environment (default):**
```bash
# Deploy specific environment variable
flyte deploy my_example.py env
```
**All environments in file:**
```bash
# Deploy all TaskEnvironment objects in file
flyte deploy --all my_example.py
```
**Recursive directory deployment:**
```bash
# Deploy all environments in directory tree
flyte deploy --recursive ./src
# Combine with comprehensive bundling
flyte deploy --recursive --copy-style all ./project
```
## `--copy-style`
**`flyte deploy --copy-style [loaded_modules|all|none]`**
The `--copy-style` option controls what gets packaged:
### `--copy-style loaded_modules` (default)
```bash
flyte deploy --copy-style loaded_modules my_example.py env
```
- **Includes**: Only imported Python modules from your project
- **Excludes**: Site-packages, system modules, Flyte SDK
- **Best for**: Most projects (optimal size and speed)
### `--copy-style all`
```bash
flyte deploy --copy-style all my_example.py env
```
- **Includes**: All files in project directory
- **Best for**: Projects with dynamic imports or data files
### `--copy-style none`
```bash
flyte deploy --copy-style none --version v1.0.0 my_example.py env
```
- **Requires**: Explicit version parameter
- **Best for**: Pre-built container images with baked-in code
## `--root-dir`
**`flyte deploy --root-dir <path>`**
The `--root-dir` option overrides the default source directory that Flyte uses as the base for code bundling and import resolution.
This is particularly useful for monorepos and projects with complex directory structures.
### Default behavior (without `--root-dir`)
- Flyte uses the current working directory as the root
- Code bundling starts from this directory
- Import paths are resolved relative to this location
### Common use cases
**Monorepos:**
```bash
# Deploy service from monorepo root
flyte deploy --root-dir ./services/ml ./services/ml/my_example.py env
# Deploy from anywhere in the monorepo
cd ./docs/
flyte deploy --root-dir ../services/ml ../services/ml/my_example.py env
```
**Cross-directory imports:**
```bash
# When workflow imports modules from sibling directories
# Project structure: project/workflows/my_example.py imports project/src/utils.py
cd project/workflows/
flyte deploy --root-dir .. my_example.py env # Sets root to project/
```
**Working directory independence:**
```bash
# Deploy from any location while maintaining consistent bundling
flyte deploy --root-dir /path/to/project /path/to/project/my_example.py env
```
### How it works
1. **Code bundling**: Files are collected starting from `--root-dir` instead of the current working directory
2. **Import resolution**: Python imports are resolved relative to the specified root directory
3. **Path consistency**: Ensures the same directory structure in local and remote execution environments
4. **Dependency packaging**: Captures all necessary modules that may be located outside the workflow file's immediate directory
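
The effect of `--root-dir` on bundle layout can be pictured as computing each file's path relative to the root. The sketch below uses the `project/workflows/my_example.py` and `project/src/utils.py` layout from the cross-directory example. `bundle_paths` is a hypothetical helper and the `/repo` prefix is made up.

```python
from pathlib import PurePosixPath


def bundle_paths(root_dir: str, files: list[str]) -> list[str]:
    """Map source files to their locations inside the bundle (hypothetical)."""
    root = PurePosixPath(root_dir)
    # relative_to raises ValueError for a file outside root_dir, which is why
    # a root that is too narrow cannot capture sibling-directory imports
    return [str(PurePosixPath(f).relative_to(root)) for f in files]


files = ["/repo/project/workflows/my_example.py", "/repo/project/src/utils.py"]
print(bundle_paths("/repo/project", files))  # ['workflows/my_example.py', 'src/utils.py']

try:
    bundle_paths("/repo/project/workflows", files)  # root too narrow
except ValueError:
    print("src/utils.py lies outside the root directory")
```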
### Example with complex project structure
```
my-project/
├── services/
│   ├── ml/
│   │   └── my_example.py  # imports shared.utils
│   └── api/
└── shared/
    └── utils.py
```
```bash
# Deploy ML service workflows with access to shared utilities
flyte deploy --root-dir ./my-project ./my-project/services/ml/my_example.py env
```
This ensures that both `services/ml/` and `shared/` directories are included in the code bundle, allowing the workflow to successfully import `shared.utils` during remote execution.
## `--image`
**`flyte deploy --image <image_mapping>`**
The `--image` option allows you to override image URIs at deployment time without modifying your code. Format: `imagename=imageuri`
### Named image mappings
```bash
# Map specific image reference to URI
flyte deploy --image base=ghcr.io/org/base:v1.0 my_example.py env
# Multiple named image mappings
flyte deploy \
--image base=ghcr.io/org/base:v1.0 \
--image gpu=ghcr.io/org/gpu:v2.0 \
my_example.py env
```
### Default image mapping
```bash
# Override default image (used when no specific image is set)
flyte deploy --image ghcr.io/org/default:latest my_example.py env
```
### How it works
- Named mappings (e.g., `base=URI`) override images created with `Image.from_ref_name("base")`.
- Unnamed mappings (e.g., just `URI`) override the default "auto" image.
- Multiple `--image` flags can be specified.
- Mappings are resolved during the image building phase of deployment.
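
The mapping rules amount to splitting each `--image` value on its first `=`. `parse_image_mappings` is a hypothetical illustration of those rules, not the Flyte CLI's parser. It uses `None` as the key for the unnamed default mapping so that key cannot collide with a real image name.

```python
from typing import Optional


def parse_image_mappings(values: list[str]) -> dict[Optional[str], str]:
    """Collect repeated --image values into {name: uri}; None marks the default."""
    mappings: dict[Optional[str], str] = {}
    for value in values:
        name, sep, uri = value.partition("=")
        if sep:
            # Named mapping: would override Image.from_ref_name(name)
            mappings[name] = uri
        else:
            # Unnamed mapping: would override the default "auto" image
            mappings[None] = value
    return mappings


print(parse_image_mappings([
    "base=ghcr.io/org/base:v1.0",
    "gpu=ghcr.io/org/gpu:v2.0",
    "ghcr.io/org/default:latest",
]))
```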
## `--ignore-load-errors`
**`flyte deploy --ignore-load-errors `**
The `--ignore-load-errors` option allows the deployment process to continue even if some modules fail to load during the environment discovery phase. This is particularly useful for large projects or monorepos where certain modules may have missing dependencies or other issues that prevent them from being imported successfully.
```bash
# Continue deployment despite module failures
flyte deploy --recursive --ignore-load-errors ./large-project
```
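
Conceptually, `--ignore-load-errors` turns a load failure from fatal into something that is recorded and skipped. The sketch below shows that pattern with `importlib`. `load_modules` is hypothetical and does not reproduce how Flyte actually reports skipped modules. Without the flag, it re-raises the first failure, mirroring the documented default.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_modules(paths: list[Path], ignore_errors: bool = False):
    """Execute each file as a module, optionally collecting failures (sketch)."""
    loaded, failures = {}, {}
    for path in paths:
        spec = importlib.util.spec_from_file_location(path.stem, str(path))
        module = importlib.util.module_from_spec(spec)
        try:
            spec.loader.exec_module(module)
        except Exception as exc:
            if not ignore_errors:
                raise  # without the flag, a load failure stops discovery
            failures[str(path)] = repr(exc)  # with the flag: record and continue
            continue
        loaded[path.stem] = module
    return loaded, failures


project = Path(tempfile.mkdtemp())
(project / "good.py").write_text("ENV_NAME = 'ok'\n")
(project / "broken.py").write_text("import package_that_is_not_installed_xyz\n")
loaded, failures = load_modules(sorted(project.glob("*.py")), ignore_errors=True)
print(sorted(loaded), list(failures.values()))
```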
## `--no-sync-local-sys-paths`
**`flyte deploy --no-sync-local-sys-paths `**
The `--no-sync-local-sys-paths` option disables the automatic synchronization of local `sys.path` entries to the remote container environment. This is an advanced option for specific deployment scenarios.
### Default behavior (path synchronization enabled)
- Flyte captures local `sys.path` entries that are under the root directory
- These paths are passed to the remote container via the `_F_SYS_PATH` environment variable
- At runtime, the remote container adds these paths to its `sys.path`, maintaining the same import environment
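
The round trip can be sketched as two small functions: one filters local `sys.path` entries under the root directory, and the other re-applies them in the container. Storing the entries relative to the root, joining them with `os.pathsep`, and the `/root/bundle` extraction directory are all assumptions made for illustration. Only the `_F_SYS_PATH` variable name comes from this page.

```python
import os
from pathlib import Path

ENV_VAR = "_F_SYS_PATH"  # variable name from the docs; its value format is assumed


def paths_to_sync(sys_path: list[str], root_dir: str) -> list[str]:
    """Local side: keep sys.path entries under root_dir, relative to it."""
    root = Path(root_dir).resolve()
    synced = []
    for entry in sys_path:
        path = Path(entry or ".").resolve()  # "" in sys.path means the CWD
        if path.is_relative_to(root):
            synced.append(path.relative_to(root).as_posix())
    return synced


def apply_synced_paths(env: dict, bundle_dir: str, sys_path: list[str]) -> None:
    """Remote side: append synced entries, re-rooted at the bundle directory."""
    for rel in filter(None, env.get(ENV_VAR, "").split(os.pathsep)):
        candidate = str(Path(bundle_dir) / rel)
        if candidate not in sys_path:
            sys_path.append(candidate)


local_sys_path = ["/work/project", "/work/project/src", "/usr/lib/python3.12"]
synced = paths_to_sync(local_sys_path, "/work/project")
container_env = {ENV_VAR: os.pathsep.join(synced)}

remote_sys_path = ["/usr/lib/python3.12"]
apply_synced_paths(container_env, "/root/bundle", remote_sys_path)
print(synced)           # ['.', 'src']
print(remote_sys_path)  # ['/usr/lib/python3.12', '/root/bundle', '/root/bundle/src']
```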
### When to disable path synchronization
```bash
# Disable local sys.path sync (advanced use case)
flyte deploy --no-sync-local-sys-paths my_example.py env
```
### Use cases for disabling
- **Custom container images**: When your container already has the correct `sys.path` configuration
- **Conflicting path structures**: When local development paths would interfere with container paths
- **Security concerns**: When you don't want to expose local development directory structures
- **Minimal environments**: When you want precise control over what gets added to the container's Python path
### How it works
- **Enabled (default)**: Local paths like `./my_project/utils` get synchronized and added to remote `sys.path`
- **Disabled**: Only the container's native `sys.path` is used, along with the deployed code bundle
Most users should leave path synchronization enabled unless they have specific requirements for container path isolation or are using pre-configured container environments.
## SDK deployment options
The core deployment functionality is available programmatically through the `flyte.deploy()` function, though some CLI-specific options are not applicable:
```python
import flyte

env = flyte.TaskEnvironment(name="my_env")

@env.task
async def process_data(data: str) -> str:
    return f"Processed: {data}"

if __name__ == "__main__":
    flyte.init_from_config()
    # Comprehensive deployment configuration
    deployment = flyte.deploy(
        env,                          # Environment to deploy
        dryrun=False,                 # Set to True for dry run
        version="v1.2.0",             # Explicit version tag
        copy_style="loaded_modules",  # Code bundling strategy
    )
    print(f"Deployment successful: {deployment[0].summary_repr()}")
```
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/task-deployment/packaging ===
# Code packaging for remote execution
When you run Flyte tasks remotely, your code needs to be available in the execution environment. The Flyte SDK provides two main approaches for packaging your code:
1. **Code bundling** - Bundle code dynamically at runtime
2. **Container-based deployment** - Embed code directly in container images
## Quick comparison
| Aspect | Code bundling | Container-based |
|--------|---------------|-----------------|
| **Speed** | Fast (no image rebuild) | Slower (requires image build) |
| **Best for** | Rapid development, iteration | Production, immutable deployments |
| **Code changes** | Immediate effect | Requires image rebuild |
| **Setup** | Automatic by default | Manual configuration needed |
| **Reproducibility** | Excellent (hash-based versioning) | Excellent (immutable images) |
| **Rollback** | Requires version control | Tag-based, straightforward |
---
## Code bundling
**Default approach** - Automatically bundles and uploads your code to remote storage at runtime.
### How it works
When you run `flyte run` or call `flyte.run()`, Flyte automatically:
1. **Scans loaded modules** from your codebase
2. **Creates a tarball** (gzipped, without timestamps for consistent hashing)
3. **Uploads to blob storage** (S3, GCS, Azure Blob)
4. **Deduplicates** based on content hashes
5. **Downloads in containers** at runtime
This process happens transparently - every container downloads and extracts the code bundle before execution.
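
The tarball steps rely on stripping everything from the archive that varies between builds. This standard-library sketch, which is not Flyte's bundler, fixes member order, timestamps, and ownership, so identical file contents produce identical bytes and therefore an identical hash. SHA-256 is used here for illustration; which digest Flyte uses is not stated on this page.

```python
import gzip
import hashlib
import io
import tarfile


def deterministic_bundle(files: dict[str, bytes]) -> bytes:
    """Build a .tar.gz whose bytes depend only on file names and contents."""
    tar_buffer = io.BytesIO()
    with tarfile.open(fileobj=tar_buffer, mode="w") as tar:
        for name in sorted(files):  # fixed member order
            info = tarfile.TarInfo(name)
            info.size = len(files[name])
            info.mtime = 0  # no timestamps in the archive
            info.mode = 0o644
            info.uid = info.gid = 0
            info.uname = info.gname = ""
            tar.addfile(info, io.BytesIO(files[name]))
    # mtime=0 also keeps the current time out of the gzip header
    return gzip.compress(tar_buffer.getvalue(), mtime=0)


def content_hash(bundle: bytes) -> str:
    return hashlib.sha256(bundle).hexdigest()


a = deterministic_bundle({"app.py": b"print('hi')\n", "my_module.py": b"X = 1\n"})
b = deterministic_bundle({"my_module.py": b"X = 1\n", "app.py": b"print('hi')\n"})
c = deterministic_bundle({"app.py": b"print('bye')\n", "my_module.py": b"X = 1\n"})
print(content_hash(a) == content_hash(b))  # True: same code, same hash
print(content_hash(a) == content_hash(c))  # False: changed code, new hash
```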
> [!NOTE]
> Code bundling is optimized for speed:
> - Bundles are created without timestamps for consistent hashing
> - Identical code produces identical hashes, enabling deduplication
> - Only modified code triggers new uploads
> - Containers cache downloaded bundles
>
> **Reproducibility:** Flyte automatically versions code bundles based on content hash. The same code always produces the same hash, guaranteeing reproducibility without manual versioning. However, version control is still recommended for rollback capabilities.
### Automatic code bundling
**Default behavior** - Bundles all loaded modules automatically.
#### What gets bundled
Flyte includes modules that are:
- ✅ **Loaded when the environment is parsed** (imported at module level)
- ✅ **Part of your codebase** (not system packages)
- ✅ **Within your project directory**
- ❌ **NOT lazily loaded** (imported inside functions)
- ❌ **NOT system-installed packages** (e.g., from site-packages)
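
The lazy-load rule follows from how module-level discovery works: a module imported inside a function body is not in `sys.modules` until that function runs. The plain-Python sketch below demonstrates this with two throwaway modules (`app_demo` and `another_module_demo`, both hypothetical) written to a temporary directory. It does not use Flyte itself.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Throwaway project: app_demo imports another_module_demo inside a function
project = Path(tempfile.mkdtemp())
(project / "another_module_demo.py").write_text("def util():\n    return 'ok'\n")
(project / "app_demo.py").write_text(
    "def process():\n"
    "    from another_module_demo import util  # lazy import\n"
    "    return util()\n"
)
sys.path.insert(0, str(project))

app = importlib.import_module("app_demo")  # like loading the environment file
loaded_before_call = "another_module_demo" in sys.modules
app.process()  # the lazy import only happens when the function body runs
loaded_after_call = "another_module_demo" in sys.modules
print(loaded_before_call, loaded_after_call)  # False True
```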
#### Example: Basic automatic bundling
```python
# app.py
import flyte
from my_module import helper  # ✅ Bundled automatically

env = flyte.TaskEnvironment(
    name="default",
    image=flyte.Image.from_debian_base().with_pip_packages("pandas", "numpy")
)

@env.task
def process_data(x: int) -> int:
    # This import won't be bundled (lazy load)
    from another_module import util  # ❌ Not bundled automatically
    return helper.transform(x)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(process_data, x=42)
    print(run.url)
```
When you run this:
```bash
flyte run app.py process_data --x 42
```
Flyte automatically:
1. Bundles `app.py` and `my_module.py`
2. Preserves the directory structure
3. Uploads to blob storage
4. Makes it available in the remote container
#### Project structure example
```
my_project/
├── app.py                # Main entry point
├── tasks/
│   ├── __init__.py
│   ├── data_tasks.py     # Flyte tasks
│   └── ml_tasks.py
└── utils/
    ├── __init__.py
    ├── preprocessing.py  # Business logic
    └── models.py
```
```python
# app.py
import flyte
from tasks.data_tasks import load_data  # ✅ Bundled
from tasks.ml_tasks import train_model  # ✅ Bundled
# utils modules imported at module level in the task modules are also bundled

@flyte.task
def pipeline(dataset: str) -> float:
    data = load_data(dataset)
    accuracy = train_model(data)
    return accuracy

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(pipeline, dataset="train.csv")
```
**All loaded project modules are bundled with their directory structure preserved.**
### Manual code bundling
Control exactly what gets bundled by configuring the copy style.
#### Copy styles
Three options are available:
1. **`"loaded_modules"`** (default) - Bundle loaded modules only
2. **`"all"`** - Bundle everything in the working directory
3. **`"none"`** - Skip bundling entirely (requires code in container)
#### Using `copy_style="all"`
Bundle all files under your project directory:
```python
import flyte
flyte.init_from_config()
# Bundle everything in current directory (my_task is defined elsewhere)
run = flyte.with_runcontext(copy_style="all").run(
    my_task,
    input_data="sample.csv",
)
```
Or via CLI:
```bash
flyte run --copy-style=all app.py my_task --input-data sample.csv
```
**Use when:**
- You have data files or configuration that tasks need
- You use dynamic imports or lazy loading
- You want to ensure all project files are available
#### Using `copy_style="none"`
Skip code bundling (see **Run and deploy tasks > Code packaging for remote execution > Container-based deployment**):
```python
run = flyte.with_runcontext(copy_style="none").run(my_task, x=10)
```
### Controlling the root directory
The `root_dir` parameter controls which directory serves as the bundling root.
#### Why root directory matters
1. **Determines what gets bundled** - All code paths are relative to `root_dir`
2. **Preserves import structure** - Python imports must match the bundle structure
3. **Affects path resolution** - Files and modules are located relative to `root_dir`
#### Setting root directory
##### Via CLI
```bash
flyte run --root-dir /path/to/project app.py my_task
```
##### Programmatically
```python
import pathlib
import flyte
flyte.init_from_config(
    root_dir=pathlib.Path(__file__).parent
)
```
#### Root directory use cases
##### Use case 1: Multi-module project
```
project/
├── src/
│   ├── workflows/
│   │   └── pipeline.py
│   └── utils/
│       └── helpers.py
└── config.yaml
```
```python
# src/workflows/pipeline.py
import pathlib
import flyte
from utils.helpers import process  # Relative import from project root

# Set root to project root (not src/)
flyte.init_from_config(
    root_dir=pathlib.Path(__file__).parent.parent.parent
)

@flyte.task
def my_task():
    return process()
```
**Root set to `project/` so imports like `from utils.helpers` work correctly.**
##### Use case 2: Shared utilities
```
workspace/
├── shared/
│   └── common.py
└── project/
    └── app.py
```
```python
# project/app.py
import flyte
import pathlib
from shared.common import shared_function # Import from parent directory
# Set root to workspace/ to include shared/
flyte.init_from_config(
    root_dir=pathlib.Path(__file__).parent.parent
)
```
##### Use case 3: Monorepo
```
monorepo/
├── libs/
│   ├── data/
│   └── models/
└── services/
    └── ml_service/
        └── workflows.py
```
```python
# services/ml_service/workflows.py
import flyte
import pathlib
from libs.data import loader # Import from monorepo root
from libs.models import predictor
# Set root to monorepo/ to include libs/
flyte.init_from_config(
    root_dir=pathlib.Path(__file__).parent.parent.parent
)
```
#### Root directory best practices
1. **Set `root_dir` at project initialization** before importing any task modules
2. **Use absolute paths** with `pathlib.Path(__file__).parent` navigation
3. **Match your import structure** - if imports are relative to project root, set `root_dir` to project root
4. **Keep consistent** - use the same `root_dir` for both `flyte run` and `flyte.init()`
### Code bundling examples
#### Example: Standard Python package
```
my_package/
├── pyproject.toml
└── src/
    └── my_package/
        ├── __init__.py
        ├── main.py
        ├── data/
        │   ├── loader.py
        │   └── processor.py
        └── models/
            └── analyzer.py
```
```python
# src/my_package/main.py
import flyte
import pathlib
from my_package.data.loader import fetch_data
from my_package.data.processor import clean_data
from my_package.models.analyzer import analyze

env = flyte.TaskEnvironment(
    name="pipeline",
    image=flyte.Image.from_debian_base().with_uv_project(
        pyproject_file=pathlib.Path(__file__).parent.parent.parent / "pyproject.toml"
    )
)

@env.task
async def fetch_task(url: str) -> dict:
    return await fetch_data(url)

@env.task
def process_task(raw_data: dict) -> list[dict]:
    return clean_data(raw_data)

@env.task
def analyze_task(data: list[dict]) -> str:
    return analyze(data)

if __name__ == "__main__":
    import flyte.git

    # Set root to project root for proper imports
    flyte.init_from_config(
        flyte.git.config_from_root(),
        root_dir=pathlib.Path(__file__).parent.parent.parent
    )
    # All modules bundled automatically
    run = flyte.run(analyze_task, data=[{"value": 1}, {"value": 2}])
    print(f"Run URL: {run.url}")
```
**Run with:**
```bash
cd my_package
flyte run src/my_package/main.py analyze_task --data '[{"value": 1}]'
```
#### Example: Dynamic environment based on domain
```python
# environment_picker.py
import flyte

def create_env():
    """Create different environments based on domain."""
    if flyte.current_domain() == "development":
        return flyte.TaskEnvironment(
            name="dev",
            image=flyte.Image.from_debian_base(),
            env_vars={"ENV": "dev", "DEBUG": "true"}
        )
    elif flyte.current_domain() == "staging":
        return flyte.TaskEnvironment(
            name="staging",
            image=flyte.Image.from_debian_base(),
            env_vars={"ENV": "staging", "DEBUG": "false"}
        )
    else:  # production
        return flyte.TaskEnvironment(
            name="prod",
            image=flyte.Image.from_debian_base(),
            env_vars={"ENV": "production", "DEBUG": "false"},
            resources=flyte.Resources(cpu="2", memory="4Gi")
        )

env = create_env()

@env.task
async def process(n: int) -> int:
    import os
    print(f"Running in {os.getenv('ENV')} environment")
    return n * 2

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(process, n=5)
    print(run.url)
```
**Why this works:**
- `flyte.current_domain()` is set correctly when Flyte re-instantiates modules remotely
- Environment configuration is deterministic and reproducible
- Code automatically bundled with domain-specific settings
> [!NOTE]
> `flyte.current_domain()` only works after `flyte.init()` is called:
> - Works with `flyte run` and `flyte deploy` (they initialize automatically)
> - Works in `if __name__ == "__main__"` after an explicit `flyte.init()`
> - Does NOT work at module level without initialization
### When to use code bundling
**Use code bundling when:**
- Rapid development and iteration
- Frequently changing code
- Multiple developers testing changes
- Jupyter notebook workflows
- Quick prototyping and experimentation
**Consider container-based instead when:**
- Need easy rollback to previous versions (container tags are simpler than finding git commits)
- Working with air-gapped environments (no blob storage access)
- Code changes require coordinated dependency updates
---
## Container-based deployment
**Advanced approach** - Embed code directly in container images for immutable deployments.
### How it works
Instead of bundling code at runtime:
1. **Build container image** with code copied inside
2. **Disable code bundling** with `copy_style="none"`
3. **Container has everything** needed at runtime
**Trade-off:** Every code change requires a new image build (slower), but provides complete reproducibility.
### Configuration
Three key steps:
#### 1. Set `copy_style="none"`
Disable runtime code bundling:
```python
flyte.with_runcontext(copy_style="none").run(my_task, n=10)
```
Or via CLI:
```bash
flyte run --copy-style=none app.py my_task --n 10
```
#### 2. Copy Code into Image
Use `Image.with_source_file()` or `Image.with_source_folder()`:
```python
import pathlib
import flyte
env = flyte.TaskEnvironment(
name="embedded",
image=flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent,
copy_contents_only=True
)
)
```
#### 3. Set Correct `root_dir`
Match your image copy configuration:
```python
flyte.init_from_config(
root_dir=pathlib.Path(__file__).parent
)
```
### Image source copying methods
#### `with_source_file()` - Copy individual files
Copy a single file into the container:
```python
image = flyte.Image.from_debian_base().with_source_file(
src=pathlib.Path(__file__),
dst="/app/main.py"
)
```
**Use for:**
- Single-file workflows
- Copying configuration files
- Adding scripts to existing images
#### `with_source_folder()` - Copy directories
Copy entire directories into the container:
```python
image = flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent,
dst="/app",
copy_contents_only=False # Copy folder itself
)
```
**Parameters:**
- `src`: Source directory path
- `dst`: Destination path in container (optional, defaults to workdir)
- `copy_contents_only`: If `True`, copies folder contents; if `False`, copies folder itself
##### `copy_contents_only=True` (Recommended)
Copies only the contents of the source folder:
```python
# Project structure:
# my_project/
# ├── app.py
# └── utils.py
image = flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent,
copy_contents_only=True
)
# Container will have:
# /app/app.py
# /app/utils.py
# Set root_dir to match:
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
```
##### `copy_contents_only=False`
Copies the folder itself with its name:
```python
# Project structure:
# workspace/
# └── my_project/
#     ├── app.py
#     └── utils.py
image = flyte.Image.from_debian_base().with_source_folder(
src=pathlib.Path(__file__).parent, # Points to my_project/
copy_contents_only=False
)
# Container will have:
# /app/my_project/app.py
# /app/my_project/utils.py
# Set root_dir to parent to match:
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent.parent)
```
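The difference between the two `copy_contents_only` settings can be reproduced locally with `shutil.copytree`. This is only a sketch of the resulting file layout, assuming a local directory stands in for `/app` in the image; it does not call the Flyte image builder, and `simulate_copy` is a hypothetical helper.

```python
import shutil
import tempfile
from pathlib import Path


def simulate_copy(src: Path, dst: Path, copy_contents_only: bool) -> list[str]:
    """Copy src under dst using each copy_contents_only layout (local simulation only).

    Returns the sorted file paths relative to dst.
    """
    # True: the folder's contents land directly in dst.
    # False: the folder itself (keeping its name) lands inside dst.
    target = dst if copy_contents_only else dst / src.name
    shutil.copytree(src, target, dirs_exist_ok=True)
    return sorted(p.relative_to(dst).as_posix() for p in dst.rglob("*") if p.is_file())


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        project = Path(tmp) / "workspace" / "my_project"
        project.mkdir(parents=True)
        (project / "app.py").write_text("print('app')\n")
        (project / "utils.py").write_text("FACTOR = 2\n")

        # ['app.py', 'utils.py'] -> set root_dir to my_project/
        print(simulate_copy(project, Path(tmp) / "app_true", copy_contents_only=True))
        # ['my_project/app.py', 'my_project/utils.py'] -> set root_dir to workspace/
        print(simulate_copy(project, Path(tmp) / "app_false", copy_contents_only=False))
```

The extra `my_project/` path segment in the second layout is exactly why `root_dir` moves up one level when `copy_contents_only=False`.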
### Complete container-based example
```python
# full_build.py
import pathlib
import flyte
from dep import helper  # Local module

# Configure environment with source copying
env = flyte.TaskEnvironment(
    name="full_build",
    image=flyte.Image.from_debian_base()
        .with_pip_packages("numpy", "pandas")
        .with_source_folder(
            src=pathlib.Path(__file__).parent,
            copy_contents_only=True
        )
)

@env.task
def square(x: int) -> int:
    return x ** helper.get_exponent()

@env.task
def main(n: int) -> list[int]:
    return list(flyte.map(square, range(n)))

if __name__ == "__main__":
    import flyte.git

    # Initialize with matching root_dir
    flyte.init_from_config(
        flyte.git.config_from_root(),
        root_dir=pathlib.Path(__file__).parent
    )
    # Run with copy_style="none" and explicit version
    run = flyte.with_runcontext(
        copy_style="none",
        version="v1.0.0"  # Explicit version for image tagging
    ).run(main, n=10)
    print(f"Run URL: {run.url}")
    run.wait()
```
**Project structure:**
```
project/
├── full_build.py
├── dep.py          # Local dependency
└── .flyte/
    └── config.yaml
```
**Run with:**
```bash
python full_build.py
```
This will:
1. Build a container image with `full_build.py` and `dep.py` embedded
2. Tag it as `v1.0.0`
3. Push to registry
4. Execute remotely without code bundling
### Using externally built images
When containers are built outside of Flyte (e.g., in CI/CD), use `Image.from_ref_name()`:
#### Step 1: Build your image externally
```dockerfile
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
# Copy your code
COPY src/ /app/
# Install dependencies
RUN pip install flyte pandas numpy
# Ensure flyte executable is available
RUN flyte --help
```
```bash
# Build in CI/CD
docker build -t myregistry.com/my-app:v1.2.3 .
docker push myregistry.com/my-app:v1.2.3
```
#### Step 2: Reference image by name
```python
# app.py
import flyte

env = flyte.TaskEnvironment(
    name="external",
    image=flyte.Image.from_ref_name("my-app-image")  # Reference name
)

@env.task
def process(x: int) -> int:
    return x * 2

if __name__ == "__main__":
    flyte.init_from_config()
    # Pass actual image URI at deploy/run time
    run = flyte.with_runcontext(
        copy_style="none",
        images={"my-app-image": "myregistry.com/my-app:v1.2.3"}
    ).run(process, x=10)
```
Or via CLI:
```bash
flyte run \
--copy-style=none \
--image my-app-image=myregistry.com/my-app:v1.2.3 \
app.py process --x 10
```
**For deployment:**
```bash
flyte deploy \
--image my-app-image=myregistry.com/my-app:v1.2.3 \
app.py
```
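Conceptually, a reference name is a key that is resolved to a concrete URI only at run or deploy time. The sketch below is a hypothetical illustration of that resolution step, not Flyte's CLI parser: it turns `NAME=URI` pairs like the ones passed to `--image` into a mapping and fails loudly when a reference has no URI.

```python
def parse_image_overrides(pairs: list[str]) -> dict[str, str]:
    """Turn ["name=uri", ...] into {"name": "uri", ...}; the first '=' splits each pair."""
    overrides: dict[str, str] = {}
    for pair in pairs:
        name, sep, uri = pair.partition("=")
        if not sep or not name or not uri:
            raise ValueError(f"expected NAME=URI, got {pair!r}")
        overrides[name] = uri
    return overrides


def resolve_image(ref_name: str, overrides: dict[str, str]) -> str:
    """Return the concrete URI for a reference name, or fail loudly."""
    try:
        return overrides[ref_name]
    except KeyError:
        raise LookupError(f"no image URI supplied for reference {ref_name!r}") from None


if __name__ == "__main__":
    images = parse_image_overrides(["my-app-image=myregistry.com/my-app:v1.2.3"])
    print(resolve_image("my-app-image", images))  # myregistry.com/my-app:v1.2.3
```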
#### Why use reference names?
1. **Decouples code from image URIs** - Change images without modifying code
2. **Supports multiple environments** - Different images for dev/staging/prod
3. **Integrates with CI/CD** - Build images in pipelines, reference in code
4. **Enables image reuse** - Multiple tasks can reference the same image
#### Example: Multi-environment deployment
```python
import flyte
import os

# Code references image by name
env = flyte.TaskEnvironment(
    name="api",
    image=flyte.Image.from_ref_name("api-service")
)

@env.task
def api_call(endpoint: str) -> dict:
    # Implementation
    return {"status": "success"}

if __name__ == "__main__":
    flyte.init_from_config()
    # Determine image based on environment
    environment = os.getenv("ENV", "dev")
    image_uri = {
        "dev": "myregistry.com/api-service:dev",
        "staging": "myregistry.com/api-service:staging",
        "prod": "myregistry.com/api-service:v1.2.3"
    }[environment]
    run = flyte.with_runcontext(
        copy_style="none",
        images={"api-service": image_uri}
    ).run(api_call, endpoint="/health")
```
### Container-based best practices
1. **Always set explicit versions** when using `copy_style="none"`:
```python
flyte.with_runcontext(copy_style="none", version="v1.0.0")
```
2. **Match `root_dir` to `copy_contents_only`**:
- `copy_contents_only=True` → `root_dir=Path(__file__).parent`
- `copy_contents_only=False` → `root_dir=Path(__file__).parent.parent`
3. **Ensure `flyte` executable is in container** - Add to PATH or install flyte package
4. **Use `.dockerignore`** to exclude unnecessary files:
```
# .dockerignore
__pycache__/
*.pyc
.git/
.venv/
*.egg-info/
```
5. **Test containers locally** before deploying:
```bash
docker run -it myimage:latest /bin/bash
# Then, inside the container shell:
python -c "import mymodule" # Verify imports work
```
### When to use container-based deployment
**Use container-based when:**
- Deploying to production
- Need immutable, reproducible environments
- Working with complex system dependencies
- Deploying to air-gapped or restricted environments
- CI/CD pipelines with automated builds
- Code changes are infrequent
**Don't use container-based when:**
- Rapid development and frequent code changes
- Quick prototyping
- Interactive development (Jupyter notebooks)
- Learning and experimentation
---
## Choosing the right approach
### Decision tree
```
Are you iterating quickly on code?
├─ Yes → Use Code Bundling (Default)
│        (Development, prototyping, notebooks)
│        Both approaches are fully reproducible via hash/tag
└─ No → Do you need easy version rollback?
        ├─ Yes → Use Container-based
        │        (Production, CI/CD, straightforward tag-based rollback)
        └─ No → Either works
                 (Code bundling is simpler, container-based for air-gapped)
```
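The same decision tree can be written as a small function, which may be handy in a deployment script. This is a hypothetical helper, not part of Flyte; the `air_gapped` parameter is an assumed input that captures the note in the final branch.

```python
from enum import Enum


class Approach(Enum):
    CODE_BUNDLING = "code bundling"
    CONTAINER_BASED = "container-based"


def choose_approach(iterating_quickly: bool, need_easy_rollback: bool, air_gapped: bool = False) -> Approach:
    """Mirror the decision tree: quick iteration first, then rollback needs."""
    if iterating_quickly:
        return Approach.CODE_BUNDLING
    if need_easy_rollback:
        return Approach.CONTAINER_BASED
    # "Either works": prefer the simpler default unless blob storage is unreachable.
    return Approach.CONTAINER_BASED if air_gapped else Approach.CODE_BUNDLING


if __name__ == "__main__":
    print(choose_approach(iterating_quickly=False, need_easy_rollback=True).value)  # container-based
```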
### Hybrid approach
You can use different approaches for different tasks:
```python
import flyte
import pathlib

# Fast iteration for development tasks
dev_env = flyte.TaskEnvironment(
    name="dev",
    image=flyte.Image.from_debian_base().with_pip_packages("pandas")
    # Code bundling (default)
)

# Immutable containers for production tasks
prod_env = flyte.TaskEnvironment(
    name="prod",
    image=flyte.Image.from_debian_base()
        .with_pip_packages("pandas")
        .with_source_folder(pathlib.Path(__file__).parent, copy_contents_only=True)
    # Requires copy_style="none"
)

@dev_env.task
def experimental_task(x: int) -> int:
    # Rapid development with code bundling
    return x * 2

@prod_env.task
def stable_task(x: int) -> int:
    # Production with embedded code
    return x ** 2

if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    # Use code bundling for dev task
    dev_run = flyte.run(experimental_task, x=5)
    # Use container-based for prod task
    prod_run = flyte.with_runcontext(
        copy_style="none",
        version="v1.0.0"
    ).run(stable_task, x=5)
```
---
## Troubleshooting
### Import errors
**Problem:** `ModuleNotFoundError` when task executes remotely
**Solutions:**
1. **Check loaded modules** - Ensure modules are imported at module level:
```python
# Good: bundled automatically
from mymodule import helper

@flyte.task
def my_task():
    return helper.process()
```
```python
# Bad: not bundled (lazy import inside the task)
@flyte.task
def my_task():
    from mymodule import helper
    return helper.process()
```
2. **Verify `root_dir`** matches your import structure:
```python
# If imports are: from mypackage.utils import foo
# Then root_dir should be parent of mypackage/
flyte.init_from_config(root_dir=pathlib.Path(__file__).parent.parent)
```
3. **Use `copy_style="all"`** to bundle everything:
```bash
flyte run --copy-style=all app.py my_task
```
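To see why the lazy import is a problem, the sketch below (plain Python, no Flyte) imports a throwaway module and checks `sys.modules` before and after calling its task function. It assumes, as the "check loaded modules" advice implies, that bundling is driven by the modules already loaded when your task module is imported; the `demo_*` module names are hypothetical.

```python
import importlib
import sys
import tempfile
from pathlib import Path

MODULES = ("demo_tasks", "demo_eager_dep", "demo_lazy_dep")

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "demo_eager_dep.py").write_text("VALUE = 1\n")
    (root / "demo_lazy_dep.py").write_text("VALUE = 2\n")
    (root / "demo_tasks.py").write_text(
        "import demo_eager_dep  # module-level import\n"
        "\n"
        "def my_task():\n"
        "    import demo_lazy_dep  # lazy import inside the task\n"
        "    return demo_eager_dep.VALUE + demo_lazy_dep.VALUE\n"
    )
    sys.path.insert(0, str(root))
    importlib.invalidate_caches()
    try:
        tasks = importlib.import_module("demo_tasks")
        # What a snapshot of loaded modules would see right after import:
        loaded_at_import = {m for m in MODULES if m in sys.modules}
        result = tasks.my_task()
        loaded_after_call = {m for m in MODULES if m in sys.modules}
    finally:
        # Clean up so the demo leaves no trace in this interpreter.
        sys.path.remove(str(root))
        for name in MODULES:
            sys.modules.pop(name, None)

print(sorted(loaded_at_import))   # ['demo_eager_dep', 'demo_tasks']
print(sorted(loaded_after_call))  # ['demo_eager_dep', 'demo_lazy_dep', 'demo_tasks']
```

`demo_lazy_dep` only appears after the task body runs, which on a remote worker is too late for it to have been packaged.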
### Code changes not reflected
**Problem:** Remote execution uses old code despite local changes
> [!NOTE]
> This is rare with code bundling - Flyte automatically versions based on content hash, so code changes should be detected automatically. This issue typically occurs with caching problems or when using `copy_style="none"`.
**Solutions:**
1. **Use explicit version bump** (mainly for container-based deployments):
```python
run = flyte.with_runcontext(version="v2").run(my_task)
```
2. **Check if `copy_style="none"`** is set - this requires image rebuild:
```python
# If using copy_style="none", rebuild image
run = flyte.with_runcontext(
copy_style="none",
version="v2" # Bump version to force rebuild
).run(my_task)
```
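The note above says code bundling versions code by content hash. The sketch below illustrates that idea under a simple assumed scheme (Flyte's actual algorithm is not described here, and `content_version` is hypothetical): any edit to a bundled file yields a new version automatically, whereas with `copy_style="none"` the code lives in the image, so you have to bump the version yourself.

```python
import hashlib
import tempfile
from pathlib import Path


def content_version(root: Path, pattern: str = "*.py") -> str:
    """Short SHA-256 over each matching file's relative path and bytes (hypothetical scheme)."""
    digest = hashlib.sha256()
    for path in sorted(root.rglob(pattern)):
        # Include the relative path so renames also change the version.
        digest.update(path.relative_to(root).as_posix().encode("utf-8"))
        digest.update(b"\0")
        digest.update(path.read_bytes())
        digest.update(b"\0")
    return digest.hexdigest()[:16]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "app.py").write_text("def f():\n    return 1\n")
        before = content_version(root)
        (root / "app.py").write_text("def f():\n    return 2\n")
        print(before != content_version(root))  # True: the edit produced a new version
```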
### Files missing in container
**Problem:** Task can't find data files or configs
**Solutions:**
1. **Use `copy_style="all"`** to bundle all files:
```bash
flyte run --copy-style=all app.py my_task
```
2. **Copy files explicitly in image**:
```python
image = flyte.Image.from_debian_base().with_source_file(
src=pathlib.Path("config.yaml"),
dst="/app/config.yaml"
)
```
3. **Store data in remote storage** instead of bundling:
```python
@flyte.task
def my_task():
    # Read from S3/GCS instead of local files
    import flyte.io
    data = flyte.io.File("s3://bucket/data.csv").open().read()
```
### Container build failures
**Problem:** Image build fails with `copy_style="none"`
**Solutions:**
1. **Check `root_dir` matches `copy_contents_only`**:
```python
from pathlib import Path
import flyte
from flyte import Image

# copy_contents_only=True
image = Image.from_debian_base().with_source_folder(
    src=Path(__file__).parent,
    copy_contents_only=True
)
flyte.init(root_dir=Path(__file__).parent)  # Match!
```
2. **Ensure `flyte` executable available**:
```python
image = Image.from_debian_base() # Has flyte pre-installed
```
3. **Check file permissions** in source directory:
```bash
chmod -R +r project/
```
### Version conflicts
**Problem:** Multiple versions of same image causing confusion
**Solutions:**
1. **Use explicit versions**:
```python
run = flyte.with_runcontext(
copy_style="none",
version="v1.2.3" # Explicit, not auto-generated
).run(my_task)
```
2. **Clean old images**:
```bash
docker image prune -a
```
3. **Use semantic versioning** for clarity:
```python
version = "v1.0.0" # Major.Minor.Patch
```
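If you follow semantic versioning, a tiny helper can compute the next explicit version to pass to `flyte.with_runcontext(version=...)`. The sketch below is hypothetical and assumes the `vMAJOR.MINOR.PATCH` format used in the examples above.

```python
import re

SEMVER = re.compile(r"v(\d+)\.(\d+)\.(\d+)")


def bump_version(version: str, part: str = "patch") -> str:
    """Return the next vMAJOR.MINOR.PATCH version, resetting lower parts to zero."""
    match = SEMVER.fullmatch(version)
    if match is None:
        raise ValueError(f"not a vMAJOR.MINOR.PATCH version: {version!r}")
    major, minor, patch = (int(group) for group in match.groups())
    if part == "major":
        return f"v{major + 1}.0.0"
    if part == "minor":
        return f"v{major}.{minor + 1}.0"
    if part == "patch":
        return f"v{major}.{minor}.{patch + 1}"
    raise ValueError(f"part must be 'major', 'minor', or 'patch', got {part!r}")


if __name__ == "__main__":
    print(bump_version("v1.2.3"))           # v1.2.4
    print(bump_version("v1.2.3", "minor"))  # v1.3.0
```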
---
## Further reading
- **Run and deploy tasks > Code packaging for remote execution > Image API Reference** - Complete Image class documentation
- **Run and deploy tasks > Code packaging for remote execution > TaskEnvironment** - Environment configuration options
- [Configuration Guide](./configuration/) - Setting up Flyte config files
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/authenticating ===
# Authenticating with Union
Union supports three authentication modes to suit different environments and use cases. This guide will help you choose the right authentication method and configure it correctly.
## Quick start
For most users getting started with Union:
1. Create a configuration file:
```bash
flyte create config --endpoint https://your-endpoint.unionai.cloud
```
Optionally, you can also add a default project and domain.
```bash
flyte create config --endpoint https://your-endpoint.unionai.cloud --project flytesnacks --domain development
```
2. Run any command to authenticate:
```bash
flyte get project
```
This will automatically open your browser to complete authentication.
## Authentication modes
### PKCE authentication (browser-based) {#pkce}
**Default mode** - Uses OAuth2 PKCE flow with automatic browser authentication.
#### When to use
- Interactive development on your laptop or workstation
- Jupyter notebooks running locally or on machines with browser access
- Any environment where you can open a web browser
#### How it works
When you run any Flyte command, Union automatically:
1. Opens your default web browser
2. Prompts you to authenticate
3. Stores credentials that auto-refresh every few hours
#### Configuration
This is the default authentication type when you create a configuration from the `flyte create config` command. The generated file has the effect of:
```yaml
admin:
  endpoint: dns:///your-endpoint.hosted.unionai.cloud
  authType: Pkce
  insecure: false
```
Because PKCE is the default, `authType: Pkce` is omitted from the generated file, as is `insecure: false` (SSL stays enabled).
#### CLI usage
Simply run any command - authentication happens automatically:
```bash
flyte get project
flyte run app.py main
flyte deploy app.py
```
#### Programmatic usage
```python
import flyte
import flyte.remote as remote
# Initialize with PKCE authentication (default)
flyte.init(endpoint="dns:///your-endpoint.hosted.unionai.cloud")
print([t for t in remote.Task.listall(project="flytesnacks", domain="development")])
```
If your configuration file is accessible, you can also initialize with `init_from_config`:
```python
import flyte
flyte.init_from_config("/path/to/config.yaml")
```
Or omit the path to pick up the configuration from the default locations:
```python
flyte.init_from_config()
```
### Device flow authentication {#device-flow}
**For headless or browser-restricted environments** - Uses OAuth2 device flow with code verification.
#### When to use
- Remote servers without GUI/browser access
- Hosted notebook environments (Google Colab, AWS SageMaker, Azure ML)
- SSH sessions or terminal-only environments
- Docker containers where browser redirect isn't possible
#### How it works
When you run a command, Union displays a URL and user code. You:
1. Open the URL on any browser (on any device)
2. Enter the displayed code
3. Complete authentication
4. Return to your terminal - the session is now authenticated
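While you complete steps 1 to 3, a device flow client (RFC 8628) keeps polling the token endpoint, which is also why the code eventually expires. The sketch below shows that polling loop in plain Python, with an injected `request_token` callable standing in for the HTTP call; it is a hypothetical illustration, not the Flyte CLI's implementation.

```python
import time
from typing import Any, Callable


def poll_for_token(
    request_token: Callable[[], dict[str, Any]],
    interval: float = 5.0,
    expires_in: float = 600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> dict[str, Any]:
    """Poll until the user approves, following the RFC 8628 error codes."""
    waited = 0.0
    while waited < expires_in:
        response = request_token()
        error = response.get("error")
        if error is None:
            return response  # success: contains the access token
        if error == "slow_down":
            interval += 5.0  # RFC 8628 asks clients to back off by 5 seconds
        elif error != "authorization_pending":
            # e.g. "access_denied" or "expired_token"
            raise RuntimeError(f"device authorization failed: {error}")
        sleep(interval)
        waited += interval
    raise TimeoutError("device code expired; run the command again to get a new code")


if __name__ == "__main__":
    # Fake token endpoint: pending twice, then success.
    replies = iter([
        {"error": "authorization_pending"},
        {"error": "authorization_pending"},
        {"access_token": "demo-token"},
    ])
    print(poll_for_token(lambda: next(replies), interval=0.01))
```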
#### Configuration
Create or update your config to use device flow:
```bash
flyte create config --endpoint https://your-endpoint.unionai.cloud --auth-type headless
```
Your config file will contain:
```yaml
admin:
  authType: DeviceFlow
  endpoint: dns:///your-endpoint.hosted.unionai.cloud
```
#### CLI usage
When you run a command, you'll see:
```bash
$ flyte get app
To Authenticate, navigate in a browser to the following URL:
https://signin.hosted.unionai.cloud/activate?user_code=TKBJXFFW
```
Open that URL on any device with a browser, enter the code, and authentication completes.
#### Programmatic usage
```python
import flyte

# Initialize with device flow authentication
flyte.init(endpoint="dns:///your-union-endpoint", headless=True)

# Your code here
@flyte.task
def my_task():
    return "Hello Union!"
```
**Example: Google Colab**
```python
# In a Colab notebook
import flyte

# This will display a URL and code in the cell output
flyte.init(
    endpoint="dns:///your-union-endpoint",
    headless=True
)

# Define and run your workflows
@flyte.task
def process_data(data: str) -> str:
    return f"Processed: {data}"
```
### API key authentication (OAuth2 app credentials) {#api-key}
**For automated and CI/CD environments** - Uses OAuth2 client credentials encoded as an API key.
#### When to use
- CI/CD pipelines (GitHub Actions, GitLab CI, Jenkins)
- Automated deployment scripts
- Production workloads
- Any non-interactive environment
- Service-to-service authentication
#### How it works
Union encodes OAuth2 client credentials (client ID and client secret) into a single API key string. This key contains all information needed to connect to Union, including the endpoint.
> [!NOTE]
> **Security Note:** API keys are sensitive credentials. Treat them like passwords:
> - Store them in secret management systems (GitHub Secrets, AWS Secrets Manager, etc.)
> - Never commit them to version control
> - Rotate them regularly
> - Use different keys for different environments
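The key bundles a client ID and secret for the OAuth2 client credentials grant (RFC 6749, section 4.4). Union's key encoding and token endpoint are not described here, so the sketch below does not decode a real key; it only shows, with placeholder values, the standard token request that such credentials are typically exchanged with. `client_credentials_request` is a hypothetical helper.

```python
import base64
from typing import Optional
from urllib.parse import urlencode


def client_credentials_request(
    client_id: str, client_secret: str, scope: Optional[str] = None
) -> tuple[dict[str, str], str]:
    """Build headers and a form body for an OAuth2 client credentials token request.

    Assumes client_id and client_secret contain no characters that need
    form-encoding before Basic auth (see RFC 6749, section 2.3.1).
    """
    basic = base64.b64encode(f"{client_id}:{client_secret}".encode("utf-8")).decode("ascii")
    headers = {
        "Authorization": f"Basic {basic}",
        "Content-Type": "application/x-www-form-urlencoded",
    }
    form = {"grant_type": "client_credentials"}
    if scope:
        form["scope"] = scope
    return headers, urlencode(form)


if __name__ == "__main__":
    headers, body = client_credentials_request("my-ci-app", "not-a-real-secret")
    print(body)  # grant_type=client_credentials
```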
#### Setup
1. Install the Union plugin:
```bash
pip install flyteplugins-union
```
2. Confirm that the API key exists:
```bash
flyte get api-key my-ci-key
```
3. Store this key securely (e.g., in GitHub Secrets, secret manager)
#### Managing API keys
List existing keys:
```bash
flyte get api-key
```
Delete a key:
```bash
flyte delete api-key my-ci-key
```
#### Programmatic usage
```python
import flyte
import os
# Initialize with API key - endpoint is embedded in the key
api_key = os.getenv("FLYTE_API_KEY")
flyte.init(api_key=api_key)
```
**Example: Automated Script**
```python
#!/usr/bin/env python3
import flyte
import os

# Read API key from environment
api_key = os.getenv("FLYTE_API_KEY")
if not api_key:
    raise ValueError("FLYTE_API_KEY environment variable not set")

# Initialize - no endpoint needed
flyte.init(api_key=api_key)

# Deploy or run workflows
@flyte.task
def automated_task():
    return "Deployed from automation"
```
## Comparison table
| Feature | PKCE | Device Flow | API Key |
|---------|------|-------------|---------|
| **Environment** | Browser available | Headless/remote | Fully automated |
| **Authentication** | Automatic browser | Manual code entry | Non-interactive |
| **Token refresh** | Automatic | Automatic | Automatic |
| **Best for** | Local development | Remote notebooks | CI/CD, production |
| **Setup complexity** | Minimal | Minimal | Moderate (requires plugin) |
| **Security** | User credentials | User credentials | App credentials |
## Switching between authentication modes
You can switch authentication modes by updating your config file:
```bash
# Switch to PKCE
flyte create config --endpoint dns:///your-endpoint.hosted.unionai.cloud
# Switch to device flow
flyte create config --endpoint dns:///your-endpoint.hosted.unionai.cloud --auth-type headless
```
Or manually edit your `~/.flyte/config.yaml`:
```yaml
admin:
  authType: Pkce # or DeviceFlow
  endpoint: dns:///your-union-endpoint
```
## Troubleshooting
### Browser doesn't open for PKCE
If the browser doesn't open automatically:
1. Copy the URL shown in your terminal
2. Open it manually in your browser
3. Complete the authentication flow
Alternatively, switch to device flow if you're in a headless environment.
### Device flow code expires
Device flow codes typically expire after a few minutes. If your code expires:
1. Run the command again to get a new code
2. Complete the authentication before the new code expires
### API key doesn't work
Ensure you've installed the required plugin:
```bash
pip install flyteplugins-union
```
Verify your API key is set correctly:
```bash
echo $FLYTE_API_KEY
```
## Best practices
1. **Local development**: Use PKCE authentication for the best experience
2. **Remote development**: Use device flow for hosted notebooks and SSH sessions
3. **Production/CI**: Always use API keys for automated environments
4. **API key security**:
- Store in secret managers (GitHub Secrets, AWS Secrets Manager, Vault)
- Never commit to version control
- Rotate regularly
- Use different keys per environment (dev, staging, prod)
5. **Config management**: Keep a project-level config file (for example, `.flyte/config.yaml`) in source control, without secrets, to maintain consistent settings across your team
=== PAGE: https://www.union.ai/docs/v2/byoc/user-guide/considerations ===
# Considerations
Flyte 2 represents a substantial change from Flyte 1.
While the static graph execution model will soon be available and will mirror Flyte 1 almost exactly, the primary mode of execution in Flyte 2 should remain pure-Python-based.
That is, each Python-based task action can act as its own engine: it kicks off sub-actions, assembles their outputs, and passes them on to further sub-actions.
While this model of execution comes with an enormous amount of flexibility, that flexibility does warrant some caveats to keep in mind when authoring your tasks.
## Non-deterministic behavior
When a task launches another task, a new Action ID is determined.
This ID is a hash of the inputs to the task, the task definition itself, along with some other information.
The fact that this ID is consistently hashed is important when it comes to things like recovery and replay.
For example, assume you have the following tasks:
```python
@env.task
async def t1():
    val = get_int_input()
    await t2(val=val)

@env.task
async def t2(val: int): ...
```
If you run `t1`, and it launches the downstream `t2` task, and then the pod executing `t1` fails, when Flyte restarts `t1` it will automatically detect that `t2` is still running and will just use that.
If `t2` ends up finishing in the interim, those results would just be used.
However, if you introduce non-determinism into the picture, then that guarantee is no longer there.
To give a contrived example:
```python
from datetime import datetime

@env.task
async def t1():
    val = get_int_input()
    now = datetime.now()
    if now.second % 2 == 0:
        await t2(val=val)
    else:
        await t3(val=val)
```
Here, depending on what time it is, either `t2` or `t3` may end up running.
In the earlier scenario, if `t1` crashes unexpectedly, and Flyte retries the execution, a different downstream task may get kicked off instead.
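The sketch below is a simplified, hypothetical model of such an action ID; Flyte's real scheme includes other information and is not shown here. It hashes the task identity together with canonically serialized inputs, so a retry that makes the same call gets the same ID, while a retry whose branch picks a different task produces an ID that recovery cannot match to the action already running.

```python
import hashlib
import json
from typing import Any


def action_id(task_name: str, task_version: str, inputs: dict[str, Any]) -> str:
    """Hypothetical deterministic ID: SHA-256 over task identity plus sorted-key JSON inputs."""
    payload = json.dumps(
        {"task": task_name, "version": task_version, "inputs": inputs},
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]


if __name__ == "__main__":
    first_attempt = action_id("t2", "v1", {"val": 7})
    retry_same_branch = action_id("t2", "v1", {"val": 7})
    retry_other_branch = action_id("t3", "v1", {"val": 7})
    print(first_attempt == retry_same_branch)   # True: recovery can reuse the running t2
    print(first_attempt == retry_other_branch)  # False: a new, different action is launched
```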
### Dealing with non-determinism
As a developer, the best way to manage non-deterministic behavior (if it is unavoidable) is to be able to observe it and see exactly what is happening in your code. Flyte 2 provides precisely the tool needed to enable this: Traces.
With this feature, you decorate the sub-task functions in your code with `@flyte.trace`, enabling checkpointing, reproducibility, and recovery at a fine-grained level. See **Build tasks > Traces** for more details.
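To make the idea concrete, here is a plain-Python sketch of trace-style checkpointing. It is not how `flyte.trace` is implemented; `Checkpoint` is a hypothetical class that records each traced call's result in call order and returns the recorded value on replay, so a nondeterministic branch like the one above takes the same path after a retry.

```python
import functools
import random
from typing import Any, Callable


class Checkpoint:
    """Hypothetical store of traced results, replayed in call order."""

    def __init__(self) -> None:
        self.results: list[Any] = []
        self._cursor = 0

    def start_replay(self) -> None:
        """Simulate a retry: rewind to the first recorded call."""
        self._cursor = 0

    def trace(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if self._cursor < len(self.results):
                value = self.results[self._cursor]  # replay: reuse the recorded result
            else:
                value = fn(*args, **kwargs)  # first run: execute and record
                self.results.append(value)
            self._cursor += 1
            return value

        return wrapper


checkpoint = Checkpoint()


@checkpoint.trace
def coin_flip() -> bool:
    return random.random() < 0.5


def t1() -> str:
    """Stand-in for the parent task: chooses which child to launch."""
    return "t2" if coin_flip() else "t3"


if __name__ == "__main__":
    first = t1()
    checkpoint.start_replay()  # pretend the pod crashed and t1 restarted
    print(first == t1())  # True: the traced decision is replayed, not re-rolled
```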
## Type safety
In Flyte 1, the top-level workflow was defined by a Python-like DSL that was compiled into a static DAG composed of tasks, each of which was, internally, defined in real Python.
The system was able to guarantee type safety across task boundaries because the task definitions were static and the inputs and outputs were defined in a way that Flytekit could validate them.
In Flyte 2, the top-level workflow is defined by Python code that runs at runtime (unless using a compiled task).
This means that the system can no longer guarantee type safety at the workflow level.
Happily, the Python ecosystem has evolved considerably since Flyte 1, and Python type hints are now a standard way to define types.
Consequently, in Flyte 2, developers should use Python type hints and type checkers like `mypy` to ensure type safety at all levels, including the top-most task (i.e., the "workflow" level).
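A static checker is the primary tool here; for instance, `mypy` would flag a call such as `t2(int=val)` against `async def t2(val: int)`. As a complement, the sketch below shows a hypothetical runtime guard (`check_call` is not a Flyte API) that validates keyword arguments against a function's type hints. It handles plain classes such as `int` and `str` and deliberately skips generics such as `list[int]`.

```python
import inspect
from typing import Any, Callable, get_origin, get_type_hints


def check_call(fn: Callable[..., Any], **kwargs: Any) -> None:
    """Raise TypeError if kwargs don't match fn's parameters and simple type hints."""
    params = inspect.signature(fn).parameters
    hints = get_type_hints(fn)
    for name, value in kwargs.items():
        if name not in params:
            raise TypeError(f"{fn.__name__}() got an unexpected keyword argument {name!r}")
        expected = hints.get(name)
        # Only check plain classes; generics like list[int] are skipped.
        if isinstance(expected, type) and get_origin(expected) is None:
            if not isinstance(value, expected):
                raise TypeError(
                    f"{fn.__name__}(): {name!r} expects {expected.__name__}, "
                    f"got {type(value).__name__}"
                )


async def t2(val: int) -> None: ...


if __name__ == "__main__":
    check_call(t2, val=3)  # passes silently
    for bad_kwargs in ({"int": 3}, {"val": "3"}):
        try:
            check_call(t2, **bad_kwargs)
        except TypeError as err:
            print(err)
```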
## No global state
A core principle of Flyte 2 (that is also shared with Flyte 1) is that you should not try to maintain global state across your workflow.
Global state is not carried over between task containers.
In a single process Python program, global variables are available across functions.
In the distributed execution model of Flyte, each task runs in its own container, and each container is isolated from the others.
If there is some state that needs to be preserved, it must be reconstructable through repeated deterministic execution.
## Driver pod requirements
Tasks don't have to kick off downstream tasks, of course; a task may itself be a leaf-level, atomic unit of compute.
However, when tasks do run other tasks, and especially when they assemble the outputs of those tasks, the parent task becomes a driver pod of sorts.
In Flyte 1, this assembling of intermediate outputs was done by Flyte Propeller.
In Flyte 2, it's done by the parent task.
This means that the pod running your parent task must be appropriately sized and should ideally not be CPU-bound; otherwise, it will slow down the evaluation and kickoff of downstream tasks.
For example, consider this scenario:
```python
@env.task
async def t_main():
    await t1()
    local_cpu_intensive_function()
    await t2()
```
The pod running `t_main` is busy with `local_cpu_intensive_function()` between `t1` and `t2`, so `t2` is not kicked off until that work finishes. Parent tasks should ideally focus only on orchestration.
## OOM risk from materialized I/O
A more subtle point to keep in mind: unless you use the soon-to-be-released ref mode, outputs are materialized. For example, consider the following scenario:
```python
@env.task
async def produce_1gb_list() -> list[float]: ...

@env.task
async def t1():
    list_floats = await produce_1gb_list()
    await t2(floats=list_floats)
```
The pod running `t1` needs to have memory to handle that 1 GB of floats. Those numbers will be materialized in that pod's memory.
This can lead to out of memory issues.
Note that `flyte.io.File`, `flyte.io.Dir` and `flyte.io.DataFrame` will not suffer from this because while those are materialized, they're only materialized as pointers to offloaded data, so their memory footprint is much lower.
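The sketch below is a plain-Python analogy for that difference, not Flyte code: instead of returning the data itself, the producer writes it to storage and returns a small reference, which is roughly the role `flyte.io.File` plays. Here a local temporary file stands in for offloaded blob storage, and the helper names are hypothetical.

```python
import array
import sys
import tempfile
from pathlib import Path


def produce_list(n: int) -> list[float]:
    """Materialized output: every float lives in the caller's memory."""
    return [float(i) for i in range(n)]


def produce_file(n: int, directory: Path) -> Path:
    """Offloaded output: data goes to storage; only a small path is returned."""
    path = directory / "floats.bin"
    with path.open("wb") as f:
        array.array("d", (float(i) for i in range(n))).tofile(f)
    return path


def load_floats(path: Path) -> array.array:
    """Read the floats back only when (and where) they are actually needed."""
    values = array.array("d")
    values.frombytes(path.read_bytes())
    return values


if __name__ == "__main__":
    n = 200_000
    with tempfile.TemporaryDirectory() as tmp:
        ref = produce_file(n, Path(tmp))
        data = produce_list(n)
        # sys.getsizeof is shallow: the list figure excludes the float objects themselves.
        print(f"list: {sys.getsizeof(data):,} bytes; reference: {sys.getsizeof(str(ref)):,} bytes")
```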
=== PAGE: https://www.union.ai/docs/v2/byoc/tutorials ===
# Tutorials
This section contains tutorials that showcase relevant use cases and provide step-by-step instructions on how to implement various features using Flyte and Union.
### **Multi-agent trading simulation**
A multi-agent trading simulation, modeling how agents within a firm might interact, strategize, and make trades collaboratively.
### **Run LLM-generated code**
Securely execute and iterate on LLM-generated code using a code agent with error reflection and retry logic.
### **Deep research**
Build an agentic workflow for deep research with multi-step reasoning and evaluation.
### **Hyperparameter optimization**
Run large-scale HPO experiments with zero manual tracking, deterministic results, and automatic recovery.
### **Automatic prompt engineering**
Easily run prompt optimization with real-time observability, traceability, and automatic recovery.
### **Text-to-SQL**
Learn how to turn natural language questions into SQL queries with Flyte and LlamaIndex, and explore prompt optimization in practice.
Learn how to turn natural language questions into SQL queries with Flyte and LlamaIndex, and explore prompt optimization in practice.
## Subpages
- **Automatic prompt engineering**
- **Deep research**
- **Hyperparameter optimization**
- **Multi-agent trading simulation**
- **Run LLM-generated code**
- **Text-to-SQL**
=== PAGE: https://www.union.ai/docs/v2/byoc/tutorials/auto_prompt_engineering ===
# Automatic prompt engineering
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/auto_prompt_engineering).
When building with LLMs and agents, the first prompt almost never works. We usually need several iterations before results are useful. Doing this manually is slow, inconsistent, and hard to reproduce.
Flyte turns prompt engineering into a systematic process. With Flyte we can:
- Generate candidate prompts automatically.
- Run evaluations in parallel.
- Track results in real time with built-in observability.
- Recover from failures without losing progress.
- Trace the lineage of every experiment for reproducibility.
And we're not limited to prompts. Just like **Automatic prompt engineering > hyperparameter optimization** in ML, we can tune model temperature, retrieval strategies, tool usage, and more. Over time, this grows into full agentic evaluations, tracking not only prompts but also how agents behave, make decisions, and interact with their environment.
In this tutorial, we'll build an automated prompt engineering pipeline with Flyte, step by step.
## Set up the environment
First, let's configure our task environment.
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
#    "pandas==2.3.1",
#    "pyarrow==21.0.0",
#    "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from flyte.io._file import File

env = flyte.TaskEnvironment(
    name="auto-prompt-engineering",
    image=flyte.Image.from_uv_script(
        __file__, name="auto-prompt-engineering", pre=True
    ),
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    resources=flyte.Resources(cpu=1),
)

# Styles for the live HTML reports (stylesheet contents omitted here;
# see the full source file linked below).
CSS = """
"""
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
We need an OpenAI API key to call GPT-4.1 (our optimization model) and GPT-4.1 mini (the default target and review model). Add it as a Flyte secret:
```
flyte create secret openai_api_key
```
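The task environment maps the `openai_api_key` secret to the `OPENAI_API_KEY` environment variable, which LiteLLM reads for OpenAI models. A small fail-fast check, sketched below with a hypothetical `require_env` helper, surfaces a missing secret before any model calls are made rather than partway through an evaluation run.

```python
import os


def require_env(name: str) -> str:
    """Return a required environment variable, failing fast if it is missing or empty."""
    value = os.environ.get(name, "")
    if not value:
        raise RuntimeError(
            f"{name} is not set; check that the matching Flyte secret exists "
            "and is listed in the task environment's secrets."
        )
    return value


# Inside a task in this environment, the secret is exposed as OPENAI_API_KEY:
# api_key = require_env("OPENAI_API_KEY")
```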
The `CSS` string defined alongside the task environment holds the styles for the live HTML reports that track prompt optimization in real time.

## Prepare the evaluation dataset
Next, we define our golden dataset: a set of inputs with known correct outputs. We use it to measure the quality of each candidate prompt.
For this tutorial, we use a small geometric shapes dataset. To keep it portable, the data prep task takes a CSV file (as a Flyte `File`, or as a URL string for remotely hosted files), shuffles it with a fixed seed, and splits it into train and test subsets: the first 150 rows for training and the next 100 for testing.
If your inputs and expected outputs already live in Google Sheets, export them as CSV with two columns: `input` and `target`.
```
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )
    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows (fixed seed for reproducibility)
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Train/test split: first 150 rows for training, next 100 for testing
    df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
    df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})
    return df_train, df_test
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
This approach works with any dataset that follows the same two-column format, so you can swap in your own with no extra dependencies.
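The only contract for your own data is a CSV with `input` and `target` columns. The sketch below uses just the standard library to show that contract and the same shape of split as `data_prep`: validate the columns, rename them to `question`/`answer`, shuffle with a fixed seed, then take 150 training rows and the next 100 as test rows. `load_and_split` is a hypothetical helper, and its shuffle order will not match pandas' `df.sample`.

```python
import csv
import io
import random


def load_and_split(
    csv_text: str, seed: int = 1234, n_train: int = 150, n_test: int = 100
) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
    """Validate an input/target CSV, shuffle it, and split it into train/test rows."""
    reader = csv.DictReader(io.StringIO(csv_text))
    if not {"input", "target"} <= set(reader.fieldnames or []):
        raise ValueError("CSV must contain 'input' and 'target' columns.")
    # Rename to the question/answer names used by the rest of the pipeline.
    rows = [{"question": r["input"], "answer": r["target"]} for r in reader]
    # Seeded shuffle for reproducibility (the order differs from pandas' df.sample).
    random.Random(seed).shuffle(rows)
    return rows[:n_train], rows[n_train : n_train + n_test]


# Write your own Q&A pairs in the expected two-column format.
buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(["input", "target"])
for i in range(300):
    writer.writerow([f"Question {i}?", f"Answer {i}"])

train, test = load_and_split(buf.getvalue())
print(len(train), len(test))
```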
## Define models
We use two models to score prompts:
- **Target model**: the model whose prompt we want to optimize.
- **Review model**: the model that judges whether the target model's answers match the ground truth.
First, we capture all model parameters in a dataclass:
```
@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
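Because every model parameter lives in one dataclass, a candidate configuration is just a modified copy. The sketch below uses the standard library's `dataclasses.replace` to derive variants without mutating the baseline; the candidate prompt and the temperature values are illustrative only, not values from the tutorial.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""


baseline = ModelConfig(
    model_name="gpt-4.1-mini",
    prompt="Solve the given problem about geometric shapes. Think step by step.",
    max_tokens=10000,
)

# A candidate prompt: copy the config and change one field; the baseline is untouched.
candidate = replace(
    baseline, prompt="Count the distinct vertices in the path, then pick an option."
)

# The same pattern sweeps other parameters, for example temperature.
sweep = [replace(baseline, temperature=t) for t in (0.0, 0.4, 0.8)]
```

Copying rather than mutating keeps the baseline intact, so it can still be evaluated on the test set after a better prompt is found.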
Then we define a Flyte `trace` to call the model. Unlike a task, a trace runs within the same runtime as the parent process. Since the model is hosted externally, this keeps the call lightweight but still observable.
```
@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
You can also host your own models on Union. In this example, we deploy `gpt-oss-20b` with vLLM:
```
import union
from union.app.llm import VLLMApp
from flytekit.extras.accelerators import A10G

# Union artifact holding the cached gpt-oss-20b weights
Model = union.Artifact(name="gpt-oss-20b")

image = union.ImageSpec(
    name="vllm-gpt-oss",
    builder="union",
    apt_packages=["build-essential", "wget", "gnupg"],
    packages=[
        "union[vllm]==0.1.191b0",
        "--pre vllm==0.10.1+gptoss \
            --extra-index-url https://wheels.vllm.ai/gpt-oss/ \
            --extra-index-url https://download.pytorch.org/whl/nightly/cu128 \
            --index-strategy unsafe-best-match",
    ],
).with_commands(
    [
        # Install the CUDA 12.8 toolkit from NVIDIA's Debian 12 repository
        "wget https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb",
        "dpkg -i cuda-keyring_1.1-1_all.deb",
        "apt-get update",
        "apt-get install -y cuda-toolkit-12-8",
        "/usr/local/cuda/bin/nvcc --version",
        # Give the runtime user ownership of these directories
        "chown -R union /root",
        "chown -R union /home",
    ]
)

gpt_oss_app = VLLMApp(
    name="gpt-oss-20b-vllm",
    model=Model.query(),
    model_id="gpt-oss",
    container_image=image,
    requests=union.Resources(cpu="5", mem="26Gi", gpu="1", ephemeral_storage="150Gi"),
    accelerator=A10G,
    scaledown_after=300,
    stream_model=True,  # stream weights straight into GPU memory
    requires_auth=False,
    extra_args="--async-scheduling",
    env={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/gpt_oss.py)
The app runs on an A10G GPU instance. With `stream_model=True`, the model weights are loaded directly into GPU memory. They are not first downloaded to disk and then loaded onto the GPU.
To deploy the model, first cache its weights from Hugging Face as a Union artifact:
## Evaluate prompts
We now define the evaluation process.
Each row of the dataset is evaluated concurrently, and a semaphore caps how many rows run at once. A helper, `run_grouped_task`, wraps each `generate_and_review` call in its own `flyte.group` and appends the result to the HTML report. `asyncio.gather` launches these helpers for all rows together.
`evaluate_prompt` measures accuracy as the fraction of responses that the review model judges correct against the ground truth. Each report update is flushed immediately, so you can watch the evaluation progress live in the UI.
```
# From optimizer.py: generate_and_review and run_grouped_task.
# ModelConfig and call_model are defined earlier in the same file.
async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
) -> dict:
    # Generate a response from the target model
    response = await call_model(
        target_model_config,
        [
            {"role": "system", "content": target_model_config.prompt},
            {"role": "user", "content": question},
        ],
    )

    # Fill the review prompt with the response and the ground-truth answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize the verdict; anything other than "true" counts as incorrect
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "is_correct": verdict_clean == "true",
    }


async def run_grouped_task(
    i,
    index,
    question,
    answer,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
):
    # The semaphore caps how many rows are evaluated at the same time
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
            )

        # The lock keeps the shared counters and report updates consistent
        async with counter_lock:
            counter["processed"] += 1
            if result["is_correct"]:
                counter["correct"] += 1
                correct_html = "✅ Yes"
            else:
                correct_html = "❌ No"

            # Running accuracy over all rows processed so far
            accuracy_pct = (counter["correct"] / counter["processed"]) * 100

            # Update chart
            await flyte.report.log.aio(
                f"",
                do_flush=True,
            )
            # Add row to table
            await flyte.report.log.aio(
                f"""
""",
                do_flush=True,
            )

    return result
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
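The evaluation pattern described above can be sketched offline in plain `asyncio`. This is a minimal illustration under stated assumptions, not the tutorial's `evaluate_prompt`. An `asyncio.Semaphore` bounds how many rows are in flight, an `asyncio.Lock` guards the shared counters, and `asyncio.gather` runs every row before accuracy is computed as the fraction judged correct. `review_stub` and `evaluate_rows` are hypothetical names. `review_stub` stands in for `generate_and_review` and calls no model.

```python
import asyncio


async def review_stub(index: int, question: str, answer: str) -> dict:
    # Hypothetical stand-in for generate_and_review: no target or review model is
    # called; a row counts as correct when the text already contains the answer.
    await asyncio.sleep(0.01)  # simulate model latency so rows overlap
    return {"index": index, "is_correct": answer in question}


async def evaluate_rows(rows: list[tuple[str, str]], concurrency: int = 3) -> dict:
    semaphore = asyncio.Semaphore(concurrency)  # at most `concurrency` rows in flight
    lock = asyncio.Lock()  # guards the shared counters below
    counter = {"processed": 0, "correct": 0, "in_flight": 0, "max_in_flight": 0}

    async def run_one(i: int, question: str, answer: str) -> dict:
        async with semaphore:
            async with lock:
                counter["in_flight"] += 1
                counter["max_in_flight"] = max(counter["max_in_flight"], counter["in_flight"])
            result = await review_stub(i, question, answer)
            async with lock:
                counter["in_flight"] -= 1
                counter["processed"] += 1
                counter["correct"] += int(result["is_correct"])
                # A real pipeline would also append a row to the report here
        return result

    # Schedule every row at once; the semaphore decides how many actually run
    await asyncio.gather(*(run_one(i, q, a) for i, (q, a) in enumerate(rows)))

    # Accuracy is the fraction of processed rows judged correct
    accuracy = counter["correct"] / counter["processed"] if counter["processed"] else 0.0
    return {
        "accuracy": accuracy,
        "processed": counter["processed"],
        "max_in_flight": counter["max_in_flight"],
    }


if __name__ == "__main__":
    rows = [("model answered (B)", "(B)")] * 7 + [("model answered (C)", "(B)")] * 3
    print(asyncio.run(evaluate_rows(rows, concurrency=3)))
```

Taking the lock only around counter updates keeps the slow model calls outside the critical section. As a result, the semaphore alone determines the degree of parallelism.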
## Optimize prompts
Optimization builds on evaluation. We give the optimizer model:
- the history of prompts tested so far, and
- their accuracies.
The model then proposes a new prompt.
We start with a _baseline_ evaluation of the user-provided prompt. In each iteration, the optimizer model proposes a new prompt, which we evaluate on the training data and log to the report. The loop stops when it reaches `max_iterations`.
```
# From optimizer.py: helper that logs one tested prompt and its accuracy
# to the Flyte report. See the full source file for the prompt_optimizer task.
async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    # Bar color for the report row: green above 80%, amber above 60%, red otherwise
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"
    await flyte.report.log.aio(
        f"""
{html.escape(prompt)}
{pct:.1f}%
""",
        do_flush=True,
    )
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
At the end, we return the best prompt and its accuracy. The report shows how accuracy improves over time and which prompts were tested.
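Below is a simplified, synchronous sketch of the optimization loop described above. It is not the tutorial's `prompt_optimizer`. The real task evaluates each candidate with the target and review models and logs it to the report. Here, `evaluate` and `propose` are plain callables. The sketch also assumes a layout for `prompt_scores_str`: past prompts listed in ascending order of accuracy. The new prompt is parsed from `[[double square brackets]]`, as the optimizer prompt requests. `format_history`, `extract_prompt`, `optimize`, `toy_evaluate`, and `toy_propose` are all hypothetical names.

```python
import re
from typing import Callable, Optional


def format_history(history: list[tuple[str, float]]) -> str:
    # List past prompts in ascending order of accuracy (assumed layout)
    ordered = sorted(history, key=lambda item: item[1])
    return "\n".join(f"Prompt: {p}\nAccuracy: {acc:.1%}" for p, acc in ordered)


def extract_prompt(reply: str) -> Optional[str]:
    # The optimizer model is told to wrap its new prompt in [[double square brackets]]
    match = re.search(r"\[\[(.*?)\]\]", reply, flags=re.DOTALL)
    return match.group(1).strip() if match else None


def optimize(
    baseline_prompt: str,
    evaluate: Callable[[str], float],
    propose: Callable[[str], str],
    max_iterations: int = 3,
) -> tuple[str, float]:
    # Baseline first, then at most one new candidate per iteration
    history = [(baseline_prompt, evaluate(baseline_prompt))]
    for _ in range(max_iterations):
        candidate = extract_prompt(propose(format_history(history)))
        if candidate is None or any(candidate == p for p, _ in history):
            continue  # skip replies without a usable, new prompt
        history.append((candidate, evaluate(candidate)))
    # Return the highest-scoring prompt seen, baseline included
    return max(history, key=lambda item: item[1])


# Hypothetical toy components so the sketch runs without any model
def toy_evaluate(prompt: str) -> float:
    return min(1.0, len(prompt) / 100)  # pretend longer prompts score higher


def toy_propose(prompt_scores_str: str) -> str:
    n = prompt_scores_str.count("Prompt: ")  # number of prompts tried so far
    new_prompt = "Solve step by step." + " Count the vertices." * n
    return f"Longer prompts did better so far.\n[[{new_prompt}]]"


if __name__ == "__main__":
    print(optimize("Solve step by step.", toy_evaluate, toy_propose, max_iterations=3))
```

The baseline stays in `history`. If no candidate beats the user's prompt, the loop returns that prompt unchanged.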

## Build the full pipeline
The entrypoint task wires everything together:
- Accepts model configs, dataset, iteration count, and concurrency.
- Runs data preparation.
- Calls the optimizer.
- Evaluates both baseline and best prompts on the test set.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pandas==2.3.1",
# "pyarrow==21.0.0",
# "litellm==1.75.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///
# {{docs-fragment env}}
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union
import flyte
import flyte.report
import pandas as pd
from flyte.io._file import File
env = flyte.TaskEnvironment(
name="auto-prompt-engineering",
image=flyte.Image.from_uv_script(
__file__, name="auto-prompt-engineering", pre=True
),
secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
resources=flyte.Resources(cpu=1),
)
CSS = """
"""
# {{/docs-fragment env}}
# {{docs-fragment data_prep}}
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Load Q&A data from a public Google Sheet CSV export URL and split into train/test DataFrames.
The sheet should have columns: 'input' and 'target'.
"""
df = pd.read_csv(
await csv_file.download() if isinstance(csv_file, File) else csv_file
)
if "input" not in df.columns or "target" not in df.columns:
raise ValueError("Sheet must contain 'input' and 'target' columns.")
# Shuffle rows
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
# Train/Test split
df_train = df.iloc[:150].rename(columns={"input": "question", "target": "answer"})
df_test = df.iloc[150:250].rename(columns={"input": "question", "target": "answer"})
return df_train, df_test
# {{/docs-fragment data_prep}}
# {{docs-fragment model_config}}
@dataclass
class ModelConfig:
model_name: str
hosted_model_uri: Optional[str] = None
temperature: float = 0.0
max_tokens: Optional[int] = 1000
timeout: int = 600
prompt: str = ""
# {{/docs-fragment model_config}}
# {{docs-fragment call_model}}
@flyte.trace
async def call_model(
model_config: ModelConfig,
messages: list[dict[str, str]],
) -> str:
from litellm import acompletion
response = await acompletion(
model=model_config.model_name,
api_base=model_config.hosted_model_uri,
messages=messages,
temperature=model_config.temperature,
timeout=model_config.timeout,
max_tokens=model_config.max_tokens,
)
return response.choices[0].message["content"]
# {{/docs-fragment call_model}}
# {{docs-fragment generate_and_review}}
async def generate_and_review(
index: int,
question: str,
answer: str,
target_model_config: ModelConfig,
review_model_config: ModelConfig,
) -> dict:
# Generate response from target model
response = await call_model(
target_model_config,
[
{"role": "system", "content": target_model_config.prompt},
{"role": "user", "content": question},
],
)
# Format review prompt with response + answer
review_messages = [
{
"role": "system",
"content": review_model_config.prompt.format(
response=response,
answer=answer,
),
}
]
verdict = await call_model(review_model_config, review_messages)
# Normalize verdict
verdict_clean = verdict.strip().lower()
if verdict_clean not in {"true", "false"}:
verdict_clean = "not sure"
return {
"index": index,
"model_response": response,
"is_correct": verdict_clean == "true",
}
# {{/docs-fragment generate_and_review}}
async def run_grouped_task(
i,
index,
question,
answer,
semaphore,
target_model_config,
review_model_config,
counter,
counter_lock,
):
async with semaphore:
with flyte.group(name=f"row-{i}"):
result = await generate_and_review(
index,
question,
answer,
target_model_config,
review_model_config,
)
async with counter_lock:
# Update counters
counter["processed"] += 1
if result["is_correct"]:
counter["correct"] += 1
correct_html = "β Yes"
else:
correct_html = "β No"
# Calculate accuracy
accuracy_pct = (counter["correct"] / counter["processed"]) * 100
# Update chart
await flyte.report.log.aio(
f"",
do_flush=True,
)
# Add row to table
await flyte.report.log.aio(
f"""
""",
do_flush=True,
)
return best_result.prompt, best_result.accuracy
# {{/docs-fragment prompt_optimizer}}
async def _log_prompt_row(prompt: str, accuracy: float):
"""Helper to log a single prompt/accuracy row to Flyte report."""
pct = accuracy * 100
if pct > 80:
color = "linear-gradient(90deg, #4CAF50, #81C784)"
elif pct > 60:
color = "linear-gradient(90deg, #FFC107, #FFD54F)"
else:
color = "linear-gradient(90deg, #F44336, #E57373)"
await flyte.report.log.aio(
f"""
{html.escape(prompt)}
{pct:.1f}%
""",
do_flush=True,
)
# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    csv_file: File | str = "https://dub.sh/geometric-shapes",
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="Solve the given problem about geometric shapes. Think step by step.",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""You are a review model tasked with evaluating the correctness of a response to a navigation problem.
The response may contain detailed steps and explanations, but the final answer is the key point.
Please determine if the final answer provided in the response is correct based on the ground truth number.
Respond with 'True' if the final answer is correct and 'False' if it is not.
Only respond with 'True' or 'False', nothing else.
Model Response:
{response}
Ground Truth:
{answer}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicate better quality.
{prompt_scores_str}
Each prompt was used together with a problem statement around geometric shapes.
This SVG path element draws a Options: (A) circle (B) heptagon (C) hexagon (D) kite (E) line (F) octagon (G) pentagon (H) rectangle (I) sector (J) triangle
(B)
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't worked in the past
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating system prompts. This means that there should be no placeholders in the prompt, as they cannot be filled at runtime. Instead focus on general instructions that will help the model to solve the task.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
    ),
    max_iterations: int = 3,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    # A local CSV path is uploaded; a URL string is passed through to pandas as-is
    if isinstance(csv_file, str) and os.path.isfile(csv_file):
        csv_file = await File.from_local(csv_file)
    df_train, df_test = await data_prep(csv_file)
    best_prompt, training_accuracy = await prompt_optimizer(
        df_train,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
    )
    with flyte.group(name="test_data_evaluation"):
        # Score the original prompt on the held-out test set
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )
        # Score the optimized prompt on the same test set
        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
        )
    return {
        "best_prompt": best_prompt,
        "training_accuracy": training_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }
# {{/docs-fragment auto_prompt_engineering}}
# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
## Run it
We add a simple main block so we can run the workflow as a script:
```
# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/auto_prompt_engineering/optimizer.py)
Run it with:
```
uv run --prerelease=allow optimizer.py
```

## Why this matters
Most prompt engineering pipelines start as quick scripts or notebooks. They're fine for experimenting, but they're difficult to scale, reproduce, or debug when things go wrong.
With Flyte 2, we get a more reliable setup:
- Run many evaluations in parallel with [async Python](../../user-guide/flyte-2/async#true-parallelism-for-all-workloads) or [native DSL](../../user-guide/flyte-2/async#the-flytemap-function-familiar-patterns).
- Watch accuracy improve in real time and link results back to the exact dataset, prompt, and model config used.
- Resume cleanly after failures without rerunning everything from scratch.
- Reuse the same pattern to tune other parameters like temperature, retrieval depth, or agent strategies, not just prompts.
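The parallel-evaluation idea in the first bullet boils down to bounded concurrency. Below is a minimal plain-`asyncio` sketch, not the tutorial's code: `fake_generate_and_review` is a hypothetical stand-in for the generate-then-review LLM calls, and an `asyncio.Semaphore` caps the number of calls in flight at `concurrency`.

```python
import asyncio


async def fake_generate_and_review(question: str, answer: str) -> bool:
    """Hypothetical stand-in for one generate-then-review LLM round trip."""
    await asyncio.sleep(0.01)  # simulate network latency
    return answer in question  # toy correctness rule for the sketch


async def evaluate(rows: list[tuple[str, str]], concurrency: int = 10) -> float:
    """Score all (question, answer) rows with at most `concurrency` calls in flight."""
    semaphore = asyncio.Semaphore(concurrency)

    async def score(question: str, answer: str) -> bool:
        async with semaphore:  # waits here once `concurrency` calls are running
            return await fake_generate_and_review(question, answer)

    verdicts = await asyncio.gather(*(score(q, a) for q, a in rows))
    return sum(verdicts) / len(verdicts) if verdicts else 0.0


if __name__ == "__main__":
    rows = [("the answer is (B)", "(B)"), ("no idea", "(C)")] * 50
    print(f"accuracy: {asyncio.run(evaluate(rows)):.2f}")
```

All rows are scheduled at once with `asyncio.gather`, while the semaphore keeps the number of concurrent model calls within the limit, which is what protects you from provider rate limits.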
## Next steps
You now have a working automated prompt engineering pipeline. Here's how you can take it further:
- **Optimize beyond prompts**: Tune temperature, retrieval strategies, or tool usage just like prompts.
- **Expand evaluation metrics**: Add latency, cost, robustness, or diversity alongside accuracy.
- **Move toward agentic evaluation**: Instead of single prompts, test how agents plan, use tools, and recover from failures in long-horizon tasks.
With this foundation, prompt engineering becomes repeatable, observable, and scalable, ready for production-grade LLM and agent systems.
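The same propose, evaluate, keep-the-best loop applies to parameters other than the prompt. A minimal sketch, assuming a hypothetical `ModelSettings` dataclass and a caller-supplied `evaluate` function (in practice, an evaluation run over your test set). Unlike the tutorial, which asks an optimizer LLM for new candidates, this version simply tries a fixed list of values:

```python
from collections.abc import Callable, Iterable
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class ModelSettings:
    """Hypothetical tunable knobs, similar in spirit to the tutorial's ModelConfig."""

    temperature: float = 0.0
    max_tokens: int = 1000


def tune(
    base: ModelSettings,
    field: str,
    candidates: Iterable[Any],
    evaluate: Callable[[ModelSettings], float],
) -> tuple[ModelSettings, float]:
    """Try each candidate value for `field` and keep the highest-scoring settings."""
    best, best_score = base, evaluate(base)  # the baseline is always a candidate
    for value in candidates:
        trial = replace(base, **{field: value})  # copy with one field changed
        score = evaluate(trial)
        if score > best_score:
            best, best_score = trial, score
    return best, best_score


if __name__ == "__main__":

    def peaked_at_0_3(settings: ModelSettings) -> float:
        """Toy objective that peaks at temperature 0.3; replace with a real evaluation."""
        return 1.0 - abs(settings.temperature - 0.3)

    print(tune(ModelSettings(), "temperature", [0.1, 0.3, 0.7, 1.0], peaked_at_0_3))
```

Scoring the baseline first means the result can never be worse than the starting configuration, mirroring how the tutorial reports baseline and optimized accuracy side by side.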
=== PAGE: https://www.union.ai/docs/v2/byoc/tutorials/deep-research ===
# Deep research
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/deep_research_agent); based on work by [Together AI](https://github.com/togethercomputer/open_deep_research).
This example demonstrates how to build an agentic workflow for deep research: a multi-step reasoning system that mirrors how a human researcher explores, analyzes, and synthesizes information from the web.
Deep research refers to the iterative process of thoroughly investigating a topic: identifying relevant sources, evaluating their usefulness, refining the research direction, and ultimately producing a well-structured summary or report. It's a long-running task that requires the agent to reason over time, adapt its strategy, and chain multiple steps together, making it an ideal fit for an agentic architecture.
In this example, we use:
- [Tavily](https://www.tavily.com/) to search for and retrieve high-quality online resources.
- [LiteLLM](https://litellm.ai/) to route LLM calls that perform reasoning, evaluation, and synthesis.
The agent executes a multi-step trajectory:
- Parallel search across multiple queries.
- Evaluation of retrieved results.
- Adaptive iteration: If results are insufficient, it formulates new research queries and repeats the search-evaluate cycle.
- Synthesis: Once the evaluation finds no remaining gaps, or after at most `budget` iterations, it filters the sources and produces a comprehensive research report.
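The trajectory above reduces to a bounded loop: search in parallel, ask an evaluator for follow-up queries, and stop when it returns none or the iteration budget runs out. A minimal sketch in plain `asyncio`, where `search` and `propose_follow_ups` are hypothetical stubs standing in for the Tavily and LLM calls:

```python
import asyncio


async def search(query: str) -> list[str]:
    """Hypothetical stand-in for a web search plus summarization call."""
    await asyncio.sleep(0)
    return [f"result for {query}"]


async def propose_follow_ups(topic: str, results: list[str], round_no: int) -> list[str]:
    """Hypothetical evaluator: requests more research only in the first round."""
    await asyncio.sleep(0)
    return [f"{topic} follow-up {round_no}"] if round_no == 0 else []


async def research(topic: str, queries: list[str], budget: int = 3) -> list[str]:
    # Parallel search across the initial queries
    batches = await asyncio.gather(*(search(q) for q in queries))
    results = [r for batch in batches for r in batch]
    for round_no in range(budget):
        # Evaluate; an empty list means the results are judged sufficient
        proposed = await propose_follow_ups(topic, results, round_no)
        new_queries = [q for q in proposed if q]
        if not new_queries:
            break
        batches = await asyncio.gather(*(search(q) for q in new_queries))
        results += [r for batch in batches for r in batch]
    # Synthesis would happen here; the sketch just returns the collected results
    return results


if __name__ == "__main__":
    print(asyncio.run(research("agents", ["a", "b"])))
```

In the tutorial, each of these steps is a separate `@env.task`, and source filtering plus answer generation follow the loop.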
What makes this workflow compelling is its dynamic, evolving nature. The agent isn't just following a fixed plan; it's making decisions in context, using multiple prompts and reasoning steps to steer the process.
Flyte is uniquely well-suited for this kind of system. It provides:
- Structured composition of dynamic reasoning steps
- Built-in parallelism for faster search and evaluation
- Traceability and observability into each step and iteration
- Scalability for long-running or compute-intensive workloads

Throughout this guide, we'll show how to design this workflow using the Flyte SDK, and how to unlock the full potential of agentic development with tools you already know and trust.
## Setting up the environment
Let's begin by setting up the task environment. We define the following components:
- Secrets for Together and Tavily API keys
- A custom image with required Python packages and apt dependencies (`pandoc`, `texlive-xetex`)
- An external YAML file (`prompts.yaml`) containing all LLM prompts, baked into the container image
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pydantic==2.11.5",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# "together==1.5.24",
# "markdown==3.8.2",
# "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path
import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results

TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096
logging = AgentLogger("together.open_deep_research")

env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}
# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]
    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    # First LLM call: stream a free-form research plan
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)
    SEARCH_PROMPT = prompts["plan_parsing_prompt"]
    response_json = ""
    # Second LLM call: parse the plan into JSON matching the ResearchPlan schema
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}
async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")
    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"{raw_content}\n\n{query}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result
# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""
    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")
    response = await atavily_search_results(query)
    logging.info("Tavily Search Called.")
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]
    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue
            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)
        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )
    # Drop failed summaries while keeping each summary paired with its source;
    # filtering before zipping would shift summaries onto the wrong sources
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        if isinstance(summarized_content, Exception):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}
@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    tasks = []
    results_list = []
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]
    if tasks:
        res_list = await asyncio.gather(*tasks)
        results_list.extend(res_list)
    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results
    return combined_results
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """
    # Format the search results for the LLM
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    EVALUATION_PROMPT = prompts["evaluation_prompt"]
    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{queries}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)
    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""
    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]
    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)
    logging.info(f"Filter response: {filter_response}")
    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    sources = json.loads(response_json)["sources"]
    logging.info(f"Filtered sources: {sources}")
    if max_sources != -1:
        sources = sources[:max_sources]
    # Sources are 1-based indices into results.results; skip out-of-range values
    # (a 0 would otherwise wrap around to the last result)
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]
    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
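def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think>...</think> tags"""
    while True:
        start = answer.find("<think>")
        if start == -1:
            break
        # Look for the closing tag only after the opening one
        end = answer.find("</think>", start)
        if end == -1:
            break
        answer = answer[:start] + answer[end + len("</think>") :]
    return answer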
# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]
    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk
    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"
    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)
    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]
        # Remove the matching closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()
    return answer.strip()
# {{/docs-fragment generate_research_answer}}
# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")
    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"
    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")
    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )
            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break
            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")
            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )
            results = results + new_results
            all_queries.extend(additional_queries)
    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )
    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )
    return answer
# {{/docs-fragment research_topic}}
# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )
    # Render the report as HTML and publish it to the Flyte report
    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()
    return html_content
# {{/docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
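The `_remove_thinking_tags` helper in the listing is meant to strip the `<think>...</think>` reasoning blocks that reasoning models such as DeepSeek-R1 emit before their answer. A regex-based equivalent, offered as a sketch rather than the tutorial's code, removes every well-formed, non-nested block in one pass:

```python
import re

# Non-greedy match so each <think>...</think> block is removed separately;
# DOTALL lets a block span multiple lines.
THINK_BLOCK = re.compile(r"<think>.*?</think>", flags=re.DOTALL)


def strip_think_blocks(answer: str) -> str:
    """Remove every <think>...</think> block and trim surrounding whitespace."""
    return THINK_BLOCK.sub("", answer).strip()


if __name__ == "__main__":
    raw = "<think>\nplan the report\n</think>\n# Report\nBody text"
    print(strip_think_blocks(raw))
```

An unclosed `<think>` tag is left untouched, so a truncated response is not silently emptied.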
The Python packages are declared at the top of the file using the `uv` script style:
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b6",
# "pydantic==2.11.5",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# "together==1.5.24",
# "markdown==3.8.2",
# "pymdown-extensions==10.16.1",
# ]
# ///
```
## Generate research queries
This task converts a user prompt into a list of focused queries. It makes two LLM calls to generate a high-level research plan and parse that plan into atomic search queries.
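The second call returns JSON matching the `ResearchPlan` schema, which the task reads with `json.loads(...)["queries"]`; the calling `research_topic` task then puts the topic itself first and keeps at most `max_queries - 1` generated queries. A standalone sketch of that post-processing, where `parse_queries` is a hypothetical name and the empty-entry filter and `max(..., 0)` guard are additions, not part of the tutorial:

```python
import json


def parse_queries(response_json: str, topic: str, max_queries: int) -> list[str]:
    """Parse the JSON plan and cap the query list, keeping the topic itself first."""
    plan = json.loads(response_json)
    # Addition for the sketch: ignore blank or non-string entries
    queries = [q for q in plan.get("queries", []) if isinstance(q, str) and q.strip()]
    # The topic counts as one query, so take at most max_queries - 1 from the plan;
    # max(..., 0) stops a non-positive limit from turning into a negative slice
    return [topic, *queries[: max(max_queries - 1, 0)]]


if __name__ == "__main__":
    print(parse_queries('{"queries": ["a", "b", "", "c"]}', "topic", 3))
```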
```
# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]
    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    # First LLM call: stream a free-form research plan
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)
    SEARCH_PROMPT = prompts["plan_parsing_prompt"]
    response_json = ""
    # Second LLM call: parse the plan into JSON matching the ResearchPlan schema
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}
async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")
    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"{raw_content}\n\n{query}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result
# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""
    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")
    response = await atavily_search_results(query)
    logging.info("Tavily Search Called.")
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]
    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue
            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)
        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        # Skip failed summaries while zipping (not before), so that each
        # summary stays paired with the search result it was generated from
        if isinstance(summarized_content, Exception):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}
@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    results_list = []
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]
    if tasks:
        res_list = await asyncio.gather(*tasks)
        results_list.extend(res_list)
    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results
    return combined_results
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """
    # Format the search results for the LLM
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    EVALUATION_PROMPT = prompts["evaluation_prompt"]
    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{queries}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)
    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""
    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]
    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)
    logging.info(f"Filter response: {filter_response}")
    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    sources = json.loads(response_json)["sources"]
    logging.info(f"Filtered sources: {sources}")
    if max_sources != -1:
        sources = sources[:max_sources]
    # Filter the results based on the source list. Sources are 1-based indices;
    # out-of-range values (including 0, which would otherwise wrap to the last item) are skipped
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]
    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while "<think>" in answer:
        start = answer.find("<think>")
        # Only match a closing tag that follows the opening tag
        end = answer.find("</think>", start)
        if end == -1:
            break
        answer = answer[:start] + answer[end + len("</think>") :]
    return answer
# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]
    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_tokens parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk
    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"
    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)
    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]
        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()
    return answer.strip()
# {{/docs-fragment generate_research_answer}}
# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")
    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"
    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")
    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )
            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break
            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")
            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )
            results = results + new_results
            all_queries.extend(additional_queries)
    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )
    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )
    return answer
# {{/docs-fragment research_topic}}
# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )
    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()
    return html_content
# {{/docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
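`generate_research_answer` cleans up the model's reply before returning it. `_remove_thinking_tags` drops `<think>...</think>` reasoning blocks, and a reply wrapped in a Markdown code fence is unwrapped. The stdlib-only sketch below mirrors that post-processing so you can try it without an LLM. `clean_answer` is a hypothetical name. The tag-removal loop is a slightly more defensive variant that only matches a closing tag appearing after an opening one.

```python
def remove_thinking_tags(answer: str) -> str:
    """Remove every <think>...</think> span from a model reply."""
    while True:
        start = answer.find("<think>")
        if start == -1:
            break
        # Only accept a closing tag that follows the opening tag.
        end = answer.find("</think>", start)
        if end == -1:
            break
        answer = answer[:start] + answer[end + len("</think>"):]
    return answer


def strip_code_fence(answer: str) -> str:
    """Unwrap a reply that the model wrapped in a Markdown code fence."""
    if answer.lstrip().startswith("```"):
        # Drop the opening fence line, including any language tag.
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            answer = answer[first_linebreak + 1:]
        # Drop the closing fence, if present.
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()
    return answer.strip()


def clean_answer(answer: str) -> str:
    # Hypothetical wrapper: same order as the agent (tags first, then fences).
    return strip_code_fence(remove_thinking_tags(answer))


raw = "<think>outline the report</think>```markdown\n# Report\nBody text\n```"
print(clean_answer(raw))
```

Removing the tags first matters: a reasoning block that precedes the fence would otherwise stop the `startswith("```")` check from matching.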
LLM calls go through LiteLLM. The async helper `asingle_shot_llm_call`, used throughout `agent.py`, is wrapped with `flyte.trace` for observability:
```
from typing import Any, AsyncIterator, Optional
from litellm import acompletion, completion
import flyte
# {{docs-fragment asingle_shot_llm_call}}
@flyte.trace
async def asingle_shot_llm_call(
    model: str,
    system_prompt: str,
    message: str,
    response_format: Optional[dict[str, str | dict[str, Any]]] = None,
    max_completion_tokens: int | None = None,
) -> AsyncIterator[str]:
    stream = await acompletion(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        temperature=0.0,
        response_format=response_format,
        # NOTE: max_tokens is deprecated per OpenAI API docs, use max_completion_tokens instead if possible
        # NOTE: max_completion_tokens is not currently supported by Together AI, so we use max_tokens instead
        max_tokens=max_completion_tokens,
        timeout=600,
        stream=True,
    )
    async for chunk in stream:
        content = chunk.choices[0].delta.get("content", "")
        if content:
            yield content
# {{/docs-fragment asingle_shot_llm_call}}
def single_shot_llm_call(
    model: str,
    system_prompt: str,
    message: str,
    response_format: Optional[dict[str, str | dict[str, Any]]] = None,
    max_completion_tokens: int | None = None,
) -> str:
    response = completion(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        temperature=0.0,
        response_format=response_format,
        # NOTE: max_tokens is deprecated per OpenAI API docs, use max_completion_tokens instead if possible
        # NOTE: max_completion_tokens is not currently supported by Together AI, so we use max_tokens instead
        max_tokens=max_completion_tokens,
        timeout=600,
    )
    return response.choices[0].message["content"]  # type: ignore
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/libs/utils/llms.py)
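Callers of `asingle_shot_llm_call` receive the reply as a stream of text chunks and concatenate them before use. For example, `generate_research_queries` first collects a free-text plan, then a JSON reply that it parses with `json.loads`. The sketch below reproduces that consumption pattern, with a fake async generator standing in for LiteLLM's streaming `acompletion` call. `fake_llm_stream`, `collect`, and the sample JSON payload are illustrative assumptions, not part of the example repository.

```python
import asyncio
import json
from typing import AsyncIterator


async def fake_llm_stream(text: str, chunk_size: int = 8) -> AsyncIterator[str]:
    """Stand-in for a streaming LLM call: yields the reply in small pieces."""
    for i in range(0, len(text), chunk_size):
        await asyncio.sleep(0)  # yield control, as real network I/O would
        yield text[i : i + chunk_size]


async def collect(stream: AsyncIterator[str]) -> str:
    """Concatenate streamed chunks into the full response text."""
    response = ""
    async for chunk in stream:
        response += chunk
    return response


async def parse_queries() -> list[str]:
    # The JSON reply is parsed only after the whole stream has arrived,
    # because a partial JSON document is not valid JSON.
    reply = '{"queries": ["agent orchestration", "workflow retries"]}'
    response_json = await collect(fake_llm_stream(reply))
    return json.loads(response_json)["queries"]


queries = asyncio.run(parse_queries())
print(queries)
```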
> [!NOTE]
> We use `flyte.trace` to track intermediate steps within a task, like LLM calls or specific function executions. This lightweight decorator adds observability with minimal overhead and is especially useful for inspecting reasoning chains during task execution.
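As a rough mental model only (this is not how `flyte.trace` is implemented), a tracing decorator for an async generator can pass each chunk straight through to the caller. It can also record the call's arguments and, once the stream ends, its complete output. `record_trace`, `TRACE_LOG`, and `echo_llm` below are hypothetical names used for illustration.

```python
import asyncio
import functools
from typing import Any, AsyncIterator, Callable

# Hypothetical in-memory trace store; a real system would persist this remotely.
TRACE_LOG: list[dict[str, Any]] = []


def record_trace(
    fn: Callable[..., AsyncIterator[str]],
) -> Callable[..., AsyncIterator[str]]:
    """Wrap an async generator, logging its inputs and full streamed output."""

    @functools.wraps(fn)
    async def wrapper(*args: Any, **kwargs: Any) -> AsyncIterator[str]:
        chunks: list[str] = []
        async for chunk in fn(*args, **kwargs):
            chunks.append(chunk)
            yield chunk  # the caller still sees every chunk as it arrives
        # Runs only after the caller has consumed the whole stream.
        TRACE_LOG.append(
            {"name": fn.__name__, "args": args, "kwargs": kwargs, "output": "".join(chunks)}
        )

    return wrapper


@record_trace
async def echo_llm(message: str) -> AsyncIterator[str]:
    """Hypothetical stand-in for an LLM call: streams the message word by word."""
    for word in message.split():
        yield word + " "


async def run_once() -> str:
    return "".join([chunk async for chunk in echo_llm(message="plan the research")])


print(asyncio.run(run_once()))
print(TRACE_LOG[-1]["output"])
```

One limitation of this simple design: if the caller stops reading early, the record is never written. A production tracer would also need to handle early exit.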
## Search and summarize
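We submit each research query to Tavily and summarize the results using an LLM. `search_all_queries` launches one `search_and_summarize` task per query and awaits them together with `asyncio.gather`. Because each call is a Flyte task, Flyte can run the queries in parallel on separate compute resources. Within each task, the per-result summarization calls are plain coroutines, so `asyncio.gather` runs them concurrently inside that task.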
```
async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")
    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"{raw_content}\n\n{query}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result


@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""
    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")
    response = await atavily_search_results(query)
    logging.info("Tavily Search Called.")
    async with prompts_file.open() as fh:
        data = await fh.read()
        yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]
    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue
            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)
        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        # Skip failed summaries while zipping (not before), so that each
        # summary stays paired with the search result it was generated from
        if isinstance(summarized_content, Exception):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)


@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    results_list = []
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]
    if tasks:
        res_list = await asyncio.gather(*tasks)
        results_list.extend(res_list)
    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results
    return combined_results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
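`search_and_summarize` gathers one summarization coroutine per search hit with `return_exceptions=True`, so one failed LLM call does not abort the whole query. Each failed summary then has to be dropped together with the hit it belongs to. Filtering the gathered list on its own before zipping would shift later summaries onto the wrong hits. The stdlib sketch below shows the aligned pattern, using a hypothetical `summarize` stub that fails on one input.

```python
import asyncio


async def summarize(text: str) -> str:
    """Hypothetical summarizer stub: fails on one input to simulate an API error."""
    await asyncio.sleep(0)
    if "bad" in text:
        raise RuntimeError("LLM call failed")
    return text.upper()


async def summarize_all(hits: list[str]) -> list[tuple[str, str]]:
    # Run every summarization concurrently; exceptions come back as values.
    outcomes = await asyncio.gather(
        *(summarize(hit) for hit in hits), return_exceptions=True
    )
    # Zip first, then skip failures, so each summary stays with its own hit.
    return [
        (hit, outcome)
        for hit, outcome in zip(hits, outcomes)
        if not isinstance(outcome, Exception)
    ]


pairs = asyncio.run(summarize_all(["alpha", "bad page", "gamma"]))
print(pairs)
```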
## Evaluate research completeness
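Next, we assess whether the gathered research is sufficient. As in query generation, the task makes two LLM calls. The planning model writes a free-text evaluation of the results, and the JSON model parses it into a list of additional queries, which is empty when no further research is needed.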
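`research_topic` runs this evaluation inside a loop bounded by `budget`. An empty list of additional queries ends the loop early. Otherwise, the new queries (capped at `max_queries`) are searched and their results appended. A minimal synchronous sketch of that control flow follows. `evaluate` and `search` are hypothetical stubs standing in for the LLM evaluation task and the Tavily search task.

```python
from typing import Callable


def iterative_research(
    queries: list[str],
    evaluate: Callable[[list[str], list[str]], list[str]],
    search: Callable[[list[str]], list[str]],
    budget: int,
    max_queries: int,
) -> tuple[list[str], list[str]]:
    """Search, then re-evaluate up to `budget` times; returns (results, all_queries)."""
    results = search(queries)
    all_queries = list(queries)
    for _ in range(budget):
        # Drop empty strings; an empty list means the research is complete.
        additional = [q for q in evaluate(results, all_queries) if q]
        if not additional:
            break
        additional = additional[:max_queries]
        results += search(additional)
        all_queries.extend(additional)
    return results, all_queries


# Hypothetical stub: asks for one follow-up query, then reports completeness.
def evaluate(results: list[str], asked: list[str]) -> list[str]:
    return ["follow-up", ""] if len(asked) == 1 else []


# Hypothetical stub: one fake result per query.
def search(qs: list[str]) -> list[str]:
    return [f"result for {q}" for q in qs]


results, asked = iterative_research(["topic"], evaluate, search, budget=3, max_queries=5)
print(asked)
```

With these stubs the loop stops after one follow-up round, well inside the budget of 3. This matches the agent's behavior of ending early once the evaluator returns no new queries.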
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pydantic==2.11.5",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# "together==1.5.24",
# "markdown==3.8.2",
# "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path
import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results
TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096
logging = AgentLogger("together.open_deep_research")
env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}
# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
topic: str,
planning_model: str,
json_model: str,
prompts_file: File,
) -> list[str]:
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
PLANNING_PROMPT = prompts["planning_prompt"]
plan = ""
logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=PLANNING_PROMPT,
message=f"Research Topic: {topic}",
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
plan += chunk
print(chunk, end="", flush=True)
SEARCH_PROMPT = prompts["plan_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=SEARCH_PROMPT,
message=f"Plan to be parsed: {plan}",
response_format={
"type": "json_object",
"schema": ResearchPlan.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
plan = json.loads(response_json)
return plan["queries"]
# {{/docs-fragment generate_research_queries}}
async def _summarize_content_async(
raw_content: str,
query: str,
prompt: str,
summarization_model: str,
) -> str:
"""Summarize content asynchronously using the LLM"""
logging.info("Summarizing content asynchronously using the LLM")
result = ""
async for chunk in asingle_shot_llm_call(
model=summarization_model,
system_prompt=prompt,
message=f"{raw_content}\n\n{query}",
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
result += chunk
return result
# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
query: str,
prompts_file: File,
summarization_model: str,
) -> DeepResearchResults:
"""Perform search for a single query"""
if len(query) > 400:
# NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
query = query[:400]
logging.info(f"Truncated query to 400 characters: {query}")
response = await atavily_search_results(query)
logging.info("Tavily Search Called.")
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]
with flyte.group("summarize-content"):
# Create tasks for summarization
summarization_tasks = []
result_info = []
for result in response.results:
if result.raw_content is None:
continue
task = _summarize_content_async(
result.raw_content,
query,
RAW_CONTENT_SUMMARIZER_PROMPT,
summarization_model,
)
summarization_tasks.append(task)
result_info.append(result)
# Use return_exceptions=True to prevent exceptions from propagating
summarized_contents = await asyncio.gather(
*summarization_tasks, return_exceptions=True
)
# Filter out exceptions
summarized_contents = [
result for result in summarized_contents if not isinstance(result, Exception)
]
formatted_results = []
for result, summarized_content in zip(result_info, summarized_contents):
formatted_results.append(
DeepResearchResult(
title=result.title,
link=result.link,
content=result.content,
raw_content=result.raw_content,
filtered_raw_content=summarized_content,
)
)
return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}
@env.task
async def search_all_queries(
queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
"""Execute searches for all queries in parallel"""
tasks = []
results_list = []
tasks = [
search_and_summarize(query, prompts_file, summarization_model)
for query in queries
]
if tasks:
res_list = await asyncio.gather(*tasks)
results_list.extend(res_list)
# Combine all results
combined_results = DeepResearchResults(results=[])
for results in results_list:
combined_results = combined_results + results
return combined_results
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """
    # Format the search results for the LLM
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    EVALUATION_PROMPT = prompts["evaluation_prompt"]
    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{queries}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)
    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
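The empty-list contract in the docstring is what lets `research_topic` stop early. Below is a minimal plain-Python sketch of that loop. It assumes hypothetical async callables `evaluate` and `search` standing in for `evaluate_research_completeness` and `search_all_queries`. Results are simplified to lists of strings, and no Flyte or LLM calls are involved:

```python
import asyncio
from typing import Awaitable, Callable

# Hypothetical stand-in signatures: the tutorial's tasks take more arguments
# (topic, models, prompts file) and exchange DeepResearchResults, not strings.
Evaluate = Callable[[list[str], list[str]], Awaitable[list[str]]]
Search = Callable[[list[str]], Awaitable[list[str]]]


async def iterative_research(
    results: list[str],
    queries: list[str],
    evaluate: Evaluate,
    search: Search,
    budget: int,
    max_queries: int,
) -> tuple[list[str], list[str]]:
    """Search again until the evaluator asks for nothing more or the budget is spent."""
    all_queries = list(queries)
    for _ in range(budget):
        proposed = await evaluate(results, all_queries)
        # Drop empty strings; an empty list means the research is judged complete.
        follow_ups = [q for q in proposed if q]
        if not follow_ups:
            break
        follow_ups = follow_ups[:max_queries]  # cap the fan-out of each round
        results = results + await search(follow_ups)
        all_queries.extend(follow_ups)
    return results, all_queries


async def fake_evaluate(results: list[str], queries: list[str]) -> list[str]:
    # Request one follow-up on the first round, then report completion.
    return ["follow-up query", ""] if len(queries) == 1 else []


async def fake_search(queries: list[str]) -> list[str]:
    return [f"result for {q}" for q in queries]


results, queries = asyncio.run(
    iterative_research(
        ["initial result"], ["topic"], fake_evaluate, fake_search, budget=3, max_queries=2
    )
)
print(queries)  # ['topic', 'follow-up query']
print(results)  # ['initial result', 'result for follow-up query']
```

Here the loop ends after the second evaluation, well before the budget of three rounds is used up. In `research_topic`, each round also runs inside `flyte.group(f"eval_iteration_{iteration}")`, which keeps that round's evaluation and search calls together.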
## Filter results
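In this step, the planning model judges the relevance of the deduplicated search results and ranks the sources. The JSON model then parses that ranking into a list of source indices. The task keeps at most `max_sources` of them (`-1` disables the cap) and returns the most useful results for the final synthesis.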
```
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""
    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]
    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)
    logging.info(f"Filter response: {filter_response}")
    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    sources = json.loads(response_json)["sources"]
    logging.info(f"Filtered sources: {sources}")
    # max_sources == -1 keeps every ranked source
    if max_sources != -1:
        sources = sources[:max_sources]
    # Sources are 1-based indices; skip any outside the result list
    # (an index of 0 would otherwise wrap around to the last result)
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]
    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
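The index handling at the end of `filter_results` is easy to get wrong, so here is a standalone sketch of it. It assumes, as the task's `results.results[i - 1]` lookup implies, that the parsed `sources` list holds 1-based positions into the deduplicated results. `select_sources` is a hypothetical helper, and plain strings stand in for `DeepResearchResult` objects:

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def select_sources(results: Sequence[T], sources: list[int], max_sources: int) -> list[T]:
    """Keep the results named by the 1-based indices in `sources`, in ranked order.

    max_sources == -1 keeps every ranked source; indices outside
    1..len(results) are skipped instead of raising or wrapping around.
    """
    if max_sources != -1:
        sources = sources[:max_sources]  # the cap is applied before invalid indices are dropped
    return [results[i - 1] for i in sources if 1 <= i <= len(results)]


docs = ["doc A", "doc B", "doc C"]
print(select_sources(docs, [3, 1, 0, 7, 2], max_sources=4))  # ['doc C', 'doc A']
print(select_sources(docs, [2, 3], max_sources=-1))  # ['doc B', 'doc C']
```

Because the cap is applied first, an LLM that returns invalid indices can leave fewer than `max_sources` results. The `1 <= i` check matters because an index of 0 would otherwise become `results[-1]` through Python's negative indexing.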
## Generate the final answer
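Finally, `generate_research_answer` synthesizes the filtered results into a detailed research report. When `remove_thinking_tags` is set, it strips any `<think>...</think>` reasoning blocks from the model output. It also removes a Markdown code fence if the model wrapped the whole answer in one. `research_topic` returns this report, and `main` renders it as the HTML report for the run.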
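The answer post-processing is plain string handling, so it can be tried without Flyte or an LLM. The sketch below is an illustrative re-implementation, not the tutorial's code. `remove_thinking_tags` and `strip_code_fence` are hypothetical helpers, and the sketch assumes the reasoning model wraps its hidden reasoning in `<think>...</think>`:

```python
THINK_OPEN, THINK_CLOSE = "<think>", "</think>"
FENCE = "`" * 3  # a Markdown code fence, built this way to keep the listing readable


def remove_thinking_tags(answer: str) -> str:
    """Delete every <think>...</think> block from the model output."""
    while True:
        start = answer.find(THINK_OPEN)
        if start == -1:
            break
        # Look for the closing tag only after the opening one
        end = answer.find(THINK_CLOSE, start)
        if end == -1:
            break  # unclosed block: keep the text rather than guess where it ends
        answer = answer[:start] + answer[end + len(THINK_CLOSE):]
    return answer


def strip_code_fence(answer: str) -> str:
    """Unwrap an answer that the model returned inside a Markdown code fence."""
    if answer.lstrip().startswith(FENCE):
        # Drop the opening fence line, including any language tag such as "markdown"
        first_linebreak = answer.find("\n", answer.find(FENCE))
        if first_linebreak != -1:
            answer = answer[first_linebreak + 1 :]
    if answer.rstrip().endswith(FENCE):
        answer = answer.rstrip()[: -len(FENCE)].rstrip()
    return answer.strip()


raw = "<think>plan the report</think>" + FENCE + "markdown\n# Report\nBody text\n" + FENCE
print(repr(strip_code_fence(remove_thinking_tags(raw))))  # '# Report\nBody text'
```

Every iteration of the removal loop either stops or shortens the string, so it always terminates. A stray `</think>` that appears before any `<think>` is simply left in place.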
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pydantic==2.11.5",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# "together==1.5.24",
# "markdown==3.8.2",
# "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path
import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
DeepResearchResult,
DeepResearchResults,
ResearchPlan,
SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results
TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096
logging = AgentLogger("together.open_deep_research")
env = flyte.TaskEnvironment(
name="deep-researcher",
secrets=[
flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
],
image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
.with_apt_packages("pandoc", "texlive-xetex")
.with_source_file(Path("prompts.yaml"), "/root"),
resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}
# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
topic: str,
planning_model: str,
json_model: str,
prompts_file: File,
) -> list[str]:
async with prompts_file.open() as fh:
data = await fh.read()
yaml_contents = str(data, "utf-8")
prompts = yaml.safe_load(yaml_contents)
PLANNING_PROMPT = prompts["planning_prompt"]
plan = ""
logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
async for chunk in asingle_shot_llm_call(
model=planning_model,
system_prompt=PLANNING_PROMPT,
message=f"Research Topic: {topic}",
response_format=None,
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
plan += chunk
print(chunk, end="", flush=True)
SEARCH_PROMPT = prompts["plan_parsing_prompt"]
response_json = ""
async for chunk in asingle_shot_llm_call(
model=json_model,
system_prompt=SEARCH_PROMPT,
message=f"Plan to be parsed: {plan}",
response_format={
"type": "json_object",
"schema": ResearchPlan.model_json_schema(),
},
max_completion_tokens=MAX_COMPLETION_TOKENS,
):
response_json += chunk
plan = json.loads(response_json)
return plan["queries"]
# {{/docs-fragment generate_research_queries}}
async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")
    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"{raw_content}\n\n{query}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result
# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""
    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")
    response = await atavily_search_results(query)
    logging.info("Tavily Search Called.")
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]
    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue
            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)
        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        # Skip failed summaries while each one is still paired with its source;
        # filtering the list before zip() would shift later summaries onto the
        # wrong results.
        if isinstance(summarized_content, BaseException):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}
@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    results_list = []
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]
    if tasks:
        res_list = await asyncio.gather(*tasks)
        results_list.extend(res_list)
    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results
    return combined_results
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """
    # Format the search results for the LLM
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    EVALUATION_PROMPT = prompts["evaluation_prompt"]
    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{queries}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)
    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""
    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]
    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)
    logging.info(f"Filter response: {filter_response}")
    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    sources = json.loads(response_json)["sources"]
    logging.info(f"Filtered sources: {sources}")
    if max_sources != -1:
        sources = sources[:max_sources]
    # Filter the results based on the source list. Sources are 1-based indices,
    # so skip any outside 1..len(results.results); an index of 0 would
    # otherwise wrap around to the last result.
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]
    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while True:
        start = answer.find("<think>")
        # Look for the closing tag only after the opening one
        end = answer.find("</think>", start)
        if start == -1 or end == -1:
            return answer
        answer = answer[:start] + answer[end + len("</think>") :]
# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]
    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk
    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"
    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)
    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]
        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()
    return answer.strip()
# {{/docs-fragment generate_research_answer}}
# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")
    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"
    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")
    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )
            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break
            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")
            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )
            results = results + new_results
            all_queries.extend(additional_queries)
    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )
    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )
    return answer
# {{/docs-fragment research_topic}}
# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )
    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()
    return html_content
# {{/docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
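The `search_and_summarize` task in the listing above fans out one summarization call per search result and collects them with `asyncio.gather(..., return_exceptions=True)`, so a single failed LLM call does not fail the whole task. The subtle part is discarding the failures without breaking the pairing between each summary and its source result. The stdlib-only sketch below illustrates the pattern; `summarize` and `summarize_all` are hypothetical stand-ins (not part of the tutorial code) for the real LLM call and task, and the uppercased "summary" is purely illustrative.

```python
import asyncio


async def summarize(text: str) -> str:
    # Hypothetical stand-in for an LLM summarization call; fails on empty input.
    if not text:
        raise ValueError("nothing to summarize")
    await asyncio.sleep(0)  # yield control, as a real network call would
    return text.upper()


async def summarize_all(docs: list[str]) -> list[tuple[str, str]]:
    # gather() preserves input order, so outcomes[i] belongs to docs[i].
    outcomes = await asyncio.gather(
        *(summarize(doc) for doc in docs), return_exceptions=True
    )
    # Drop failures while each outcome is still paired with its input.
    return [
        (doc, outcome)
        for doc, outcome in zip(docs, outcomes)
        if not isinstance(outcome, BaseException)
    ]


pairs = asyncio.run(summarize_all(["alpha", "", "gamma"]))
print(pairs)  # [('alpha', 'ALPHA'), ('gamma', 'GAMMA')]
```

Because `gather` returns outcomes in the same order as its inputs, checking each outcome inside the `zip` keeps every surviving summary attached to the right source. Filtering the outcome list first and zipping afterwards would instead shift every later summary onto the wrong result, one position per failure.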
## Orchestration
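Next, we define a `research_topic` task that orchestrates the entire deep research workflow. It runs the core stages in order: generating the initial research queries, searching and summarizing results for those queries in parallel, repeatedly evaluating whether the results are complete and searching any follow-up queries (for at most `budget` rounds), deduplicating and filtering the sources, and finally generating the answer.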
```
# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")
    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"
    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")
    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )
            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break
            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")
            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )
            results = results + new_results
            all_queries.extend(additional_queries)
    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )
    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )
    return answer
# {{/docs-fragment research_topic}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
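The control flow of `research_topic` can be distilled into a small, stdlib-only sketch. It is a simplification under stated assumptions, not the tutorial's implementation: `search` and `evaluate` are hypothetical stubs standing in for the `search_all_queries` and `evaluate_research_completeness` tasks, `Result` is a pared-down stand-in for `DeepResearchResult`, and query generation, LLM filtering, and answer generation are omitted. What it keeps is the shape of the loop: at most `budget` follow-up rounds, at most `max_queries` new queries per round, an early exit when the evaluator asks for nothing more, and deduplication by link at the end.

```python
import asyncio
from dataclasses import dataclass


@dataclass(frozen=True)
class Result:
    # Pared-down stand-in for DeepResearchResult.
    link: str
    summary: str


async def search(query: str) -> list[Result]:
    # Hypothetical stub for search_all_queries: one result per query.
    await asyncio.sleep(0)
    return [Result(link=f"https://example.com/{query}", summary=query)]


async def evaluate(results: list[Result], asked: list[str]) -> list[str]:
    # Hypothetical stub for evaluate_research_completeness: requests one new
    # query until three queries have been asked, then reports completion.
    await asyncio.sleep(0)
    return [f"followup-{len(asked)}"] if len(asked) < 3 else []


def dedup(results: list[Result]) -> list[Result]:
    # Keep the first result for each link, preserving order.
    seen: set[str] = set()
    unique: list[Result] = []
    for r in results:
        if r.link not in seen:
            seen.add(r.link)
            unique.append(r)
    return unique


async def research(topic: str, budget: int, max_queries: int) -> list[Result]:
    queries = [topic]  # the real task also adds LLM-generated queries here
    results: list[Result] = []
    for batch in await asyncio.gather(*(search(q) for q in queries)):
        results.extend(batch)
    for _ in range(budget):  # at most `budget` follow-up rounds
        extra = [q for q in await evaluate(results, queries) if q][:max_queries]
        if not extra:
            break  # evaluator is satisfied: stop before the budget runs out
        for batch in await asyncio.gather(*(search(q) for q in extra)):
            results.extend(batch)
        queries.extend(extra)
    return dedup(results)


final = asyncio.run(research("topic", budget=5, max_queries=3))
print([r.summary for r in final])  # ['topic', 'followup-1', 'followup-2']
```

With a generous budget, the loop stops as soon as the evaluator returns no new queries, so `budget` acts as a ceiling on follow-up rounds rather than a fixed count.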
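The `main` task wraps this entire pipeline and adds a final step that renders the answer as an HTML report: the task is declared with `report=True`, and the generated HTML is published with `flyte.report.replace.aio`.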
It also serves as the main entry point to the workflow, allowing us to pass in all configuration parameters, including which LLMs to use at each stage.
This flexibility lets us mix and match models for planning, summarization, and final synthesis, helping us optimize for both cost and quality.
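Because `budget` and `max_queries` determine how many calls each stage makes, a quick back-of-the-envelope bound helps when deciding where a cheaper or stronger model pays off. The hypothetical helper below (not part of the tutorial code) derives upper bounds from the capping logic in `research_topic`: the initial round runs the topic plus up to `max_queries - 1` generated queries, and each of at most `budget` evaluation rounds adds up to `max_queries` more. It counts only calls made inside `research_topic`, ignores task retries, and leaves out summarization calls, which scale with the number of search results returned per query.

```python
def call_bounds(budget: int, max_queries: int) -> dict[str, int]:
    """Upper bounds on calls per stage for one research_topic run (hypothetical helper)."""
    if budget < 0 or max_queries < 1:
        raise ValueError("expected budget >= 0 and max_queries >= 1")
    return {
        # Initial round plus up to `budget` follow-up rounds, each capped at max_queries.
        "searches": max_queries * (1 + budget),
        # One planning call, one per evaluation round, one for source filtering.
        "planning_model": 1 + budget + 1,
        # Each planning_model call above is followed by one JSON-parsing call.
        "json_model": 1 + budget + 1,
        # The final answer is generated once.
        "answer_model": 1,
    }


# Defaults of `main` (budget=2, max_queries=3):
print(call_bounds(budget=2, max_queries=3))
# {'searches': 9, 'planning_model': 4, 'json_model': 4, 'answer_model': 1}
```

The full `main` task follows.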
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pydantic==2.11.5",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# "together==1.5.24",
# "markdown==3.8.2",
# "pymdown-extensions==10.16.1",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment env}}
import asyncio
import json
from pathlib import Path
import flyte
import yaml
from flyte.io._file import File
from libs.utils.data_types import (
    DeepResearchResult,
    DeepResearchResults,
    ResearchPlan,
    SourceList,
)
from libs.utils.generation import generate_html, generate_toc_image
from libs.utils.llms import asingle_shot_llm_call
from libs.utils.log import AgentLogger
from libs.utils.tavily_search import atavily_search_results
TIME_LIMIT_MULTIPLIER = 5
MAX_COMPLETION_TOKENS = 4096
logging = AgentLogger("together.open_deep_research")
env = flyte.TaskEnvironment(
    name="deep-researcher",
    secrets=[
        flyte.Secret(key="together_api_key", as_env_var="TOGETHER_API_KEY"),
        flyte.Secret(key="tavily_api_key", as_env_var="TAVILY_API_KEY"),
    ],
    image=flyte.Image.from_uv_script(__file__, name="deep-research-agent", pre=True)
    .with_apt_packages("pandoc", "texlive-xetex")
    .with_source_file(Path("prompts.yaml"), "/root"),
    resources=flyte.Resources(cpu=1),
)
# {{/docs-fragment env}}
# {{docs-fragment generate_research_queries}}
@env.task
async def generate_research_queries(
    topic: str,
    planning_model: str,
    json_model: str,
    prompts_file: File,
) -> list[str]:
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    PLANNING_PROMPT = prompts["planning_prompt"]
    plan = ""
    logging.info(f"\n\nGenerated deep research plan for topic: {topic}\n\nPlan:")
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=PLANNING_PROMPT,
        message=f"Research Topic: {topic}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        plan += chunk
        print(chunk, end="", flush=True)
    SEARCH_PROMPT = prompts["plan_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=SEARCH_PROMPT,
        message=f"Plan to be parsed: {plan}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    plan = json.loads(response_json)
    return plan["queries"]
# {{/docs-fragment generate_research_queries}}
async def _summarize_content_async(
    raw_content: str,
    query: str,
    prompt: str,
    summarization_model: str,
) -> str:
    """Summarize content asynchronously using the LLM"""
    logging.info("Summarizing content asynchronously using the LLM")
    result = ""
    async for chunk in asingle_shot_llm_call(
        model=summarization_model,
        system_prompt=prompt,
        message=f"{raw_content}\n\n{query}",
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        result += chunk
    return result
# {{docs-fragment search_and_summarize}}
@env.task
async def search_and_summarize(
    query: str,
    prompts_file: File,
    summarization_model: str,
) -> DeepResearchResults:
    """Perform search for a single query"""
    if len(query) > 400:
        # NOTE: we are truncating the query to 400 characters to avoid Tavily Search issues
        query = query[:400]
        logging.info(f"Truncated query to 400 characters: {query}")
    response = await atavily_search_results(query)
    logging.info("Tavily Search Called.")
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    RAW_CONTENT_SUMMARIZER_PROMPT = prompts["raw_content_summarizer_prompt"]
    with flyte.group("summarize-content"):
        # Create tasks for summarization
        summarization_tasks = []
        result_info = []
        for result in response.results:
            if result.raw_content is None:
                continue
            task = _summarize_content_async(
                result.raw_content,
                query,
                RAW_CONTENT_SUMMARIZER_PROMPT,
                summarization_model,
            )
            summarization_tasks.append(task)
            result_info.append(result)
        # Use return_exceptions=True to prevent exceptions from propagating
        summarized_contents = await asyncio.gather(
            *summarization_tasks, return_exceptions=True
        )
    formatted_results = []
    for result, summarized_content in zip(result_info, summarized_contents):
        # Skip failed summaries while each one is still paired with its source;
        # filtering the list before zip() would shift later summaries onto the
        # wrong results.
        if isinstance(summarized_content, BaseException):
            continue
        formatted_results.append(
            DeepResearchResult(
                title=result.title,
                link=result.link,
                content=result.content,
                raw_content=result.raw_content,
                filtered_raw_content=summarized_content,
            )
        )
    return DeepResearchResults(results=formatted_results)
# {{/docs-fragment search_and_summarize}}
@env.task
async def search_all_queries(
    queries: list[str], summarization_model: str, prompts_file: File
) -> DeepResearchResults:
    """Execute searches for all queries in parallel"""
    tasks = []
    results_list = []
    tasks = [
        search_and_summarize(query, prompts_file, summarization_model)
        for query in queries
    ]
    if tasks:
        res_list = await asyncio.gather(*tasks)
        results_list.extend(res_list)
    # Combine all results
    combined_results = DeepResearchResults(results=[])
    for results in results_list:
        combined_results = combined_results + results
    return combined_results
# {{docs-fragment evaluate_research_completeness}}
@env.task
async def evaluate_research_completeness(
    topic: str,
    results: DeepResearchResults,
    queries: list[str],
    prompts_file: File,
    planning_model: str,
    json_model: str,
) -> list[str]:
    """
    Evaluate if the current search results are sufficient or if more research is needed.
    Returns an empty list if research is complete, or a list of additional queries if more research is needed.
    """
    # Format the search results for the LLM
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    EVALUATION_PROMPT = prompts["evaluation_prompt"]
    logging.info("\nEvaluation: ")
    evaluation = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=EVALUATION_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{queries}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=None,
    ):
        evaluation += chunk
        print(chunk, end="", flush=True)
    EVALUATION_PARSING_PROMPT = prompts["evaluation_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=EVALUATION_PARSING_PROMPT,
        message=f"Evaluation to be parsed: {evaluation}",
        response_format={
            "type": "json_object",
            "schema": ResearchPlan.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    evaluation = json.loads(response_json)
    return evaluation["queries"]
# {{/docs-fragment evaluate_research_completeness}}
# {{docs-fragment filter_results}}
@env.task
async def filter_results(
    topic: str,
    results: DeepResearchResults,
    prompts_file: File,
    planning_model: str,
    json_model: str,
    max_sources: int,
) -> DeepResearchResults:
    """Filter the search results based on the research plan"""
    # Format the search results for the LLM, without the raw content
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    FILTER_PROMPT = prompts["filter_prompt"]
    logging.info("\nFilter response: ")
    filter_response = ""
    async for chunk in asingle_shot_llm_call(
        model=planning_model,
        system_prompt=FILTER_PROMPT,
        message=(
            f"{topic}\n\n"
            f"{formatted_results}"
        ),
        response_format=None,
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        filter_response += chunk
        print(chunk, end="", flush=True)
    logging.info(f"Filter response: {filter_response}")
    FILTER_PARSING_PROMPT = prompts["filter_parsing_prompt"]
    response_json = ""
    async for chunk in asingle_shot_llm_call(
        model=json_model,
        system_prompt=FILTER_PARSING_PROMPT,
        message=f"Filter response to be parsed: {filter_response}",
        response_format={
            "type": "json_object",
            "schema": SourceList.model_json_schema(),
        },
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        response_json += chunk
    sources = json.loads(response_json)["sources"]
    logging.info(f"Filtered sources: {sources}")
    if max_sources != -1:
        sources = sources[:max_sources]
    # Filter the results based on the source list (source numbers are 1-based;
    # out-of-range numbers, including 0, are ignored)
    filtered_results = [
        results.results[i - 1] for i in sources if 1 <= i <= len(results.results)
    ]
    return DeepResearchResults(results=filtered_results)
# {{/docs-fragment filter_results}}
def _remove_thinking_tags(answer: str) -> str:
    """Remove content within <think> tags"""
    while "<think>" in answer and "</think>" in answer:
        start = answer.find("<think>")
        end = answer.find("</think>") + len("</think>")
        answer = answer[:start] + answer[end:]
    return answer
# {{docs-fragment generate_research_answer}}
@env.task
async def generate_research_answer(
    topic: str,
    results: DeepResearchResults,
    remove_thinking_tags: bool,
    prompts_file: File,
    answer_model: str,
) -> str:
    """
    Generate a comprehensive answer to the research topic based on the search results.
    Returns a detailed response that synthesizes information from all search results.
    """
    formatted_results = str(results)
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    prompts = yaml.safe_load(yaml_contents)
    ANSWER_PROMPT = prompts["answer_prompt"]
    answer = ""
    async for chunk in asingle_shot_llm_call(
        model=answer_model,
        system_prompt=ANSWER_PROMPT,
        message=f"Research Topic: {topic}\n\nSearch Results:\n{formatted_results}",
        response_format=None,
        # NOTE: This is the max_token parameter for the LLM call on Together AI,
        # may need to be changed for other providers
        max_completion_tokens=MAX_COMPLETION_TOKENS,
    ):
        answer += chunk
    # this is just to avoid typing complaints
    if answer is None or not isinstance(answer, str):
        logging.error("No answer generated")
        return "No answer generated"
    if remove_thinking_tags:
        # Remove content within <think> tags
        answer = _remove_thinking_tags(answer)
    # Remove markdown code block markers if they exist at the beginning
    if answer.lstrip().startswith("```"):
        # Find the first line break after the opening backticks
        first_linebreak = answer.find("\n", answer.find("```"))
        if first_linebreak != -1:
            # Remove everything up to and including the first line break
            answer = answer[first_linebreak + 1 :]
        # Remove closing code block if it exists
        if answer.rstrip().endswith("```"):
            answer = answer.rstrip()[:-3].rstrip()
    return answer.strip()
# {{/docs-fragment generate_research_answer}}
# {{docs-fragment research_topic}}
@env.task(retries=flyte.RetryStrategy(count=3, backoff=10, backoff_factor=2))
async def research_topic(
    topic: str,
    budget: int = 3,
    remove_thinking_tags: bool = True,
    max_queries: int = 5,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 40,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
    prompts_file: File | str = "prompts.yaml",
) -> str:
    """Main method to conduct research on a topic. Will be used for weave evals."""
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    # Step 1: Generate initial queries
    queries = await generate_research_queries(
        topic=topic,
        planning_model=planning_model,
        json_model=json_model,
        prompts_file=prompts_file,
    )
    queries = [topic, *queries[: max_queries - 1]]
    all_queries = queries.copy()
    logging.info(f"Initial queries: {queries}")
    if len(queries) == 0:
        logging.error("No initial queries generated")
        return "No initial queries generated"
    # Step 2: Perform initial search
    results = await search_all_queries(queries, summarization_model, prompts_file)
    logging.info(f"Initial search complete, found {len(results.results)} results")
    # Step 3: Conduct iterative research within budget
    for iteration in range(budget):
        with flyte.group(f"eval_iteration_{iteration}"):
            # Evaluate if more research is needed
            additional_queries = await evaluate_research_completeness(
                topic=topic,
                results=results,
                queries=all_queries,
                prompts_file=prompts_file,
                planning_model=planning_model,
                json_model=json_model,
            )
            # Filter out empty strings and check if any queries remain
            additional_queries = [q for q in additional_queries if q]
            if not additional_queries:
                logging.info("No need for additional research")
                break
            # for debugging purposes we limit the number of queries
            additional_queries = additional_queries[:max_queries]
            logging.info(f"Additional queries: {additional_queries}")
            # Expand research with new queries
            new_results = await search_all_queries(
                additional_queries, summarization_model, prompts_file
            )
            logging.info(
                f"Follow-up search complete, found {len(new_results.results)} results"
            )
            results = results + new_results
            all_queries.extend(additional_queries)
    # Step 4: Generate final answer
    logging.info(f"Generating final answer for topic: {topic}")
    results = results.dedup()
    logging.info(f"Deduplication complete, kept {len(results.results)} results")
    filtered_results = await filter_results(
        topic=topic,
        results=results,
        prompts_file=prompts_file,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
    )
    logging.info(
        f"LLM Filtering complete, kept {len(filtered_results.results)} results"
    )
    # Generate final answer
    answer = await generate_research_answer(
        topic=topic,
        results=filtered_results,
        remove_thinking_tags=remove_thinking_tags,
        prompts_file=prompts_file,
        answer_model=answer_model,
    )
    return answer
# {{/docs-fragment research_topic}}
# {{docs-fragment main}}
@env.task(report=True)
async def main(
    topic: str = (
        "List the essential requirements for a developer-focused agent orchestration system."
    ),
    prompts_file: File | str = "/root/prompts.yaml",
    budget: int = 2,
    remove_thinking_tags: bool = True,
    max_queries: int = 3,
    answer_model: str = "together_ai/deepseek-ai/DeepSeek-V3",
    planning_model: str = "together_ai/Qwen/Qwen2.5-72B-Instruct-Turbo",
    json_model: str = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
    max_sources: int = 10,
    summarization_model: str = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
) -> str:
    if isinstance(prompts_file, str):
        prompts_file = await File.from_local(prompts_file)
    answer = await research_topic(
        topic=topic,
        budget=budget,
        remove_thinking_tags=remove_thinking_tags,
        max_queries=max_queries,
        answer_model=answer_model,
        planning_model=planning_model,
        json_model=json_model,
        max_sources=max_sources,
        summarization_model=summarization_model,
        prompts_file=prompts_file,
    )
    async with prompts_file.open() as fh:
        data = await fh.read()
    yaml_contents = str(data, "utf-8")
    toc_image_url = await generate_toc_image(
        yaml.safe_load(yaml_contents)["data_visualization_prompt"],
        planning_model,
        topic,
    )
    html_content = await generate_html(answer, toc_image_url)
    await flyte.report.replace.aio(html_content, do_flush=True)
    await flyte.report.flush.aio()
    return html_content
# {{/docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/agent.py)
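One detail in `search_and_summarize` is worth isolating. `asyncio.gather(..., return_exceptions=True)` returns outcomes (values or exception objects) in the same order as the awaitables passed in, so failures should be skipped while each outcome is still zipped with its source item; filtering the exceptions out first and zipping afterwards would shift later summaries onto the wrong search results. The sketch below shows the pattern with stand-in coroutines; `fake_summarize` and the sample inputs are hypothetical and not part of the example repository.

```python
import asyncio


async def fake_summarize(text: str) -> str:
    """Hypothetical stand-in for an LLM summarization call."""
    await asyncio.sleep(0)
    if "bad" in text:
        raise RuntimeError(f"summarization failed for {text!r}")
    return text.upper()


async def summarize_all(items: list[str]) -> list[tuple[str, str]]:
    # gather() preserves input order; with return_exceptions=True a failed
    # coroutine yields its exception object instead of raising
    outcomes = await asyncio.gather(
        *(fake_summarize(item) for item in items), return_exceptions=True
    )
    pairs = []
    for item, outcome in zip(items, outcomes):
        # Skip failures while each item is still paired with its own outcome
        if isinstance(outcome, Exception):
            continue
        pairs.append((item, outcome))
    return pairs


if __name__ == "__main__":
    print(asyncio.run(summarize_all(["alpha", "bad beta", "gamma"])))
```

The failed middle item is dropped, and `"gamma"` keeps its own summary rather than inheriting a neighbor's.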
## Run the deep research agent
First, create the required secrets:
```
flyte create secret TOGETHER_API_KEY <your-together-api-key>
flyte create secret TAVILY_API_KEY <your-tavily-api-key>
```
Run the agent:
```
uv run --prerelease=allow agent.py
```
If you want to test it locally first, run the following commands:
```
brew install pandoc
brew install basictex # restart your terminal after install
export TOGETHER_API_KEY=<your-together-api-key>
export TAVILY_API_KEY=<your-tavily-api-key>
uv run --prerelease=allow agent.py
```
## Evaluate with Weights & Biases Weave
We use W&B Weave to evaluate the full agent pipeline. The evaluation runs as a Flyte pipeline and uses an LLM-as-a-judge scorer to grade each generated answer against the dataset's reference answer.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "weave==0.51.51",
# "datasets==3.6.0",
# "huggingface-hub==0.32.6",
# "litellm==1.72.2",
# "tavily-python==0.7.5",
# ]
# ///
import os
import weave
from agent import research_topic
from datasets import load_dataset
from huggingface_hub import login
from libs.utils.log import AgentLogger
from litellm import completion
import flyte
logging = AgentLogger()
weave.init(project_name="deep-researcher")
env = flyte.TaskEnvironment(name="deep-researcher-eval")
@weave.op
def llm_as_a_judge_scoring(answer: str, output: str, question: str) -> bool:
    prompt = f"""
Given the following question and answer, evaluate the answer against the correct answer:
<question>
{question}
</question>
<agent_answer>
{output}
</agent_answer>
<correct_answer>
{answer}
</correct_answer>
Note that the agent answer might be a long text containing a lot of information or it might be a short answer.
You should read the entire text and think if the agent answers the question somewhere
in the text. You should try to be flexible with the answer but careful.
For example, answering with names instead of name and surname is fine.
The important thing is that the answer of the agent either contains the correct answer or is equal to
the correct answer.
If the agent answer is correct, return
<reasoning>
The agent answer is correct because I can read that ....
</reasoning>
<answer>
1
</answer>
Otherwise, return
<reasoning>
The agent answer is incorrect because there is ...
</reasoning>
<answer>
0
</answer>
"""
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant that returns a number between 0 and 1.",
        },
        {"role": "user", "content": prompt},
    ]
    answer = (
        completion(
            model="together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
            messages=messages,
            max_tokens=1000,
            temperature=0.0,
        )
        .choices[0]  # type: ignore
        .message["content"]  # type: ignore
    )
    # Extract the 0/1 verdict from between the <answer> tags
    return bool(int(answer.split("<answer>")[1].split("</answer>")[0].strip()))
def authenticate_huggingface():
    """Authenticate with Hugging Face Hub using token from environment variable."""
    token = os.getenv("HUGGINGFACE_TOKEN")
    if not token:
        raise ValueError(
            "HUGGINGFACE_TOKEN environment variable not set. "
            "Please set it with your token from https://huggingface.co/settings/tokens"
        )
    try:
        login(token=token)
        print("Successfully authenticated with Hugging Face Hub")
    except Exception as e:
        raise RuntimeError(f"Failed to authenticate with Hugging Face Hub: {e!s}")
@env.task
async def load_questions(
    dataset_names: list[str] | None = None,
) -> list[dict[str, str]]:
    """
    Load questions from the specified Hugging Face dataset configurations.
    Args:
        dataset_names: List of dataset configurations to load
            Options:
                "smolagents:simpleqa",
                "hotpotqa",
                "simpleqa",
                "together-search-bench"
            If None, defaults to ["smolagents:simpleqa"]
    Returns:
        List of question-answer pairs
    """
    if dataset_names is None:
        dataset_names = ["smolagents:simpleqa"]
    all_questions = []
    # Authenticate with Hugging Face Hub (once and for all)
    authenticate_huggingface()
    for dataset_name in dataset_names:
        print(f"Loading dataset: {dataset_name}")
        try:
            if dataset_name == "together-search-bench":
                # Load Together-Search-Bench dataset
                dataset_path = "togethercomputer/together-search-bench"
                ds = load_dataset(dataset_path)
                if "test" in ds:
                    split_data = ds["test"]
                else:
                    print(f"No 'test' split found in dataset at {dataset_path}")
                    continue
                for i in range(len(split_data)):
                    item = split_data[i]
                    question_data = {
                        "question": item["question"],
                        "answer": item["answer"],
                        "dataset": item.get("dataset", "together-search-bench"),
                    }
                    all_questions.append(question_data)
                print(f"Loaded {len(split_data)} questions from together-search-bench dataset")
                continue
            elif dataset_name == "hotpotqa":
                # Load HotpotQA dataset (using distractor version for validation)
                ds = load_dataset("hotpotqa/hotpot_qa", "distractor", trust_remote_code=True)
                split_name = "validation"
            elif dataset_name == "simpleqa":
                ds = load_dataset("basicv8vc/SimpleQA")
                split_name = "test"
            else:
                # Strip "smolagents:" prefix when loading the dataset
                actual_dataset = dataset_name.split(":")[-1]
                ds = load_dataset("smolagents/benchmark-v1", actual_dataset)
                split_name = "test"
        except Exception as e:
            print(f"Failed to load dataset {dataset_name}: {e!s}")
            continue  # Skip this dataset if it fails to load
        print(f"Dataset structure for {dataset_name}: {ds}")
        print(f"Available splits: {list(ds)}")
        split_data = ds[split_name]  # type: ignore
        for i in range(len(split_data)):
            item = split_data[i]
            if dataset_name == "hotpotqa":
                # we remove questions that are easy or medium (if any) just to reduce the number of questions
                if item["level"] != "hard":
                    continue
                question_data = {
                    "question": item["question"],
                    "answer": item["answer"],
                    "dataset": dataset_name,
                }
            elif dataset_name == "simpleqa":
                # Handle SimpleQA dataset format
                question_data = {
                    "question": item["problem"],
                    "answer": item["answer"],
                    "dataset": dataset_name,
                }
            else:
                question_data = {
                    "question": item["question"],
                    "answer": item["true_answer"],
                    "dataset": dataset_name,
                }
            all_questions.append(question_data)
    print(f"Loaded {len(all_questions)} questions in total")
    return all_questions
@weave.op
async def predict(question: str):
    return await research_topic(topic=str(question))
@env.task
async def main(datasets: list[str] = ["together-search-bench"], limit: int | None = 1):
    questions = await load_questions(datasets)
    if limit is not None:
        questions = questions[:limit]
        print(f"Limited to {len(questions)} question(s)")
    evaluation = weave.Evaluation(dataset=questions, scorers=[llm_as_a_judge_scoring])
    await evaluation.evaluate(predict)
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(raw_data_path="data").run(main)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/deep_research_agent/weave_evals.py)
You can run this pipeline locally as follows:
```
export HUGGINGFACE_TOKEN=<your-huggingface-token> # https://huggingface.co/settings/tokens
export WANDB_API_KEY=<your-wandb-api-key> # https://wandb.ai/settings
uv run --prerelease=allow weave_evals.py
```
The script runs every task in the pipeline and logs the evaluation results to Weights & Biases.
You can also evaluate individual tasks, but this script focuses on end-to-end evaluation of the deep research workflow.

=== PAGE: https://www.union.ai/docs/v2/byoc/tutorials/hpo ===
# Hyperparameter optimization
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/ml/optimizer.py).
Hyperparameter Optimization (HPO) is a critical step in the machine learning (ML) lifecycle. Hyperparameters are the knobs and dials of a model: values such as learning rates, tree depths, or dropout rates that significantly impact performance but cannot be learned during training. Instead, we must select them manually or optimize them through guided search.
Model developers often enjoy the flexibility of choosing from a wide variety of model types, whether gradient boosted machines (GBMs), generalized linear models (GLMs), deep learning architectures, or dozens of others. A common challenge across all these options is the need to systematically explore model performance across hyperparameter configurations tailored to the specific dataset and task.
Thankfully, this exploration can be automated. Frameworks like [Optuna](https://optuna.org/), [Hyperopt](https://hyperopt.github.io/hyperopt/), and [Ray Tune](https://docs.ray.io/en/latest/tune/index.html) use advanced sampling algorithms to efficiently search the hyperparameter space and identify optimal configurations. HPO may be executed in two distinct ways:
- **Serial HPO** runs one trial at a time, which is easy to set up but can be painfully slow.
- **Parallel HPO** distributes trials across multiple processes. It typically follows a pattern with two parameters: **_N_**, the total number of trials to run, and **_C_**, the maximum number of trials that can run concurrently. Trials are executed asynchronously, and new ones are scheduled based on the results and status of completed or in-progress ones.
However, parallel HPO introduces a new complexity: the need for a centralized state that tracks:
- All past trials (successes and failures)
- All ongoing trials
This state is essential so that the optimization algorithm can make informed decisions about which hyperparameters to try next.
## A better way to run HPO
This is where Flyte shines.
- There's no need to manage a separate centralized database for state tracking, as every objective run is **cached**, **recorded**, and **recoverable** via Flyte's execution engine.
- The entire HPO process is observable in the UI with full lineage and metadata for each trial.
- Each objective is seeded for reproducibility, enabling deterministic trial results.
- If the main optimization task crashes or is terminated, **Flyte can resume from the last successful or failed trial, making the experiment highly fault-tolerant**.
- Trial functions can be strongly typed, enabling rich, flexible hyperparameter spaces while maintaining strict type safety across trials.
In this example, we combine Flyte with Optuna to optimize a `RandomForestClassifier` on the Iris dataset. Each trial runs in an isolated task, and the optimization process is orchestrated asynchronously, with Flyte handling the underlying scheduling, retries, and caching.
## Declare dependencies
We start by declaring a Python environment using Python 3.13 and specifying our runtime dependencies.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "optuna>=4.0.0,<5.0.0",
#    "flyte>=2.0.0b0",
#    "scikit-learn==1.7.0",
# ]
# ///
```
With the environment defined, we import the standard library and third-party modules needed for both the ML task and distributed execution.
```
import asyncio
import typing
from collections import Counter
from typing import Optional, Union
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
These standard library imports are essential for asynchronous execution (`asyncio`), type annotations (`typing`, `Optional`, `Union`), and aggregating trial state counts (`Counter`).
```
import optuna
from optuna import Trial
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.utils import shuffle
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
We use Optuna for hyperparameter optimization and several utilities from scikit-learn to prepare data (`load_iris`), define the model (`RandomForestClassifier`), evaluate it (`cross_val_score`), and shuffle the dataset for randomness (`shuffle`).
```
import flyte
import flyte.errors
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
Flyte is our orchestration framework. We use it to define tasks, manage resources, and recover from execution errors.
## Define the task environment
We define a Flyte task environment called `driver`, which encapsulates metadata, compute resources, the container image context needed for remote execution, and caching behavior.
```
driver = flyte.TaskEnvironment(
    name="driver",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="optimizer"),
    cache="auto",
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
This environment specifies that the tasks will run with 1 CPU and 250Mi of memory, the image is built using the current script (`__file__`), and caching is enabled.
You can configure the Flyte task environment to reuse containers across multiple executions by setting the `reusable` field to `flyte.ReusePolicy(replicas=..., idle_ttl=...)`. This is especially useful when the objective computations are short-lived, as it avoids paying the container spin-up cost for every trial. Learn more about reusable containers in the user guide.
## Define the optimizer
Next, we define an `Optimizer` class that handles parallel execution of Optuna trials using async coroutines. This class abstracts the full optimization loop and supports concurrent trial execution with live logging.
```
class Optimizer:
    def __init__(
        self,
        objective: typing.Callable,
        n_trials: int,
        concurrency: int = 1,
        delay: float = 0.1,
        study: Optional[optuna.Study] = None,
        log_delay: float = 0.1,
    ):
        self.n_trials: int = n_trials
        self.concurrency: int = concurrency
        self.objective: typing.Callable = objective
        self.delay: float = delay
        self.log_delay = log_delay
        self.study = study if study else optuna.create_study()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
We pass the `objective` function, number of trials to run (`n_trials`), and maximum parallel trials (`concurrency`). The optional delay throttles execution between trials, while `log_delay` controls how often logging runs. If no existing Optuna Study is provided, a new one is created automatically.
```
    async def log(self):
        while True:
            await asyncio.sleep(self.log_delay)
            counter = Counter()
            for trial in self.study.trials:
                counter[trial.state.name.lower()] += 1
            counts = dict(counter, queued=self.n_trials - len(self))
            # print items in dictionary in a readable format
            formatted = [f"{name}: {count}" for name, count in counts.items()]
            print(f"{' '.join(formatted)}")
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
This method periodically prints the number of trials in each state (e.g., running, complete, fail). It keeps users informed of ongoing optimization progress and is invoked as a background task when logging is enabled.
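The status line is just a `Counter` over trial states plus a derived `queued` count. The sketch below reproduces that formatting outside of Optuna. As a simplification, trial states are passed in as plain strings; the real method reads `trial.state.name` from `study.trials`. `format_status` is a hypothetical helper name.

```python
from collections import Counter


def format_status(states: list[str], n_trials: int) -> str:
    """Summarize trial states the way Optimizer.log does, from plain state names."""
    counter = Counter(state.lower() for state in states)
    # Trials that have not been created by study.ask() yet are reported as queued
    counts = dict(counter, queued=n_trials - len(states))
    return " ".join(f"{name}: {count}" for name, count in counts.items())


if __name__ == "__main__":
    print(format_status(["COMPLETE", "RUNNING", "COMPLETE", "FAIL"], n_trials=10))
```

States appear in the order they were first seen, with `queued` always last.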

_Logs are streamed live as the execution progresses._
```
    async def spawn(self, semaphore: asyncio.Semaphore):
        async with semaphore:
            trial: Trial = self.study.ask()
            try:
                print("Starting trial", trial.number)
                params = {
                    "n_estimators": trial.suggest_int("n_estimators", 10, 200),
                    "max_depth": trial.suggest_int("max_depth", 2, 20),
                    "min_samples_split": trial.suggest_float(
                        "min_samples_split", 0.1, 1.0
                    ),
                }
                output = await self.objective(params)
                self.study.tell(trial, output, state=optuna.trial.TrialState.COMPLETE)
            except flyte.errors.RuntimeUserError as e:
                print(f"Trial {trial.number} failed: {e}")
                self.study.tell(trial, state=optuna.trial.TrialState.FAIL)
            await asyncio.sleep(self.delay)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
Each call to `spawn` runs a single Optuna trial. The `semaphore` ensures that only a fixed number of concurrent trials are active at once, respecting the `concurrency` parameter. We first ask Optuna for a new trial and generate a parameter dictionary by querying the trial object for suggested hyperparameters. The trial is then evaluated by the objective function. If successful, we mark it as `COMPLETE`. If the trial fails due to a `RuntimeUserError` from Flyte, we log and record the failure in the Optuna study.
```
    async def __call__(self):
        # create semaphore to manage concurrency
        semaphore = asyncio.Semaphore(self.concurrency)
        # create list of async trials
        trials = [self.spawn(semaphore) for _ in range(self.n_trials)]
        logger: Optional[asyncio.Task] = None
        if self.log_delay:
            logger = asyncio.create_task(self.log())
        # await all trials to complete
        await asyncio.gather(*trials)
        if self.log_delay and logger:
            logger.cancel()
            try:
                await logger
            except asyncio.CancelledError:
                pass
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
The `__call__` method defines the overall async optimization routine. It creates the semaphore, spawns `n_trials` coroutines, and optionally starts the background logging task. All trials are awaited with `asyncio.gather`.
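The shutdown sequence at the end of `__call__` is a standard `asyncio` idiom: cancel the background task, then await it and swallow the resulting `CancelledError` so the task has fully finished before the method returns. Here is a stripped-down, self-contained version, with a hypothetical `heartbeat` coroutine in place of `log` and short sleeps in place of trials:

```python
import asyncio


async def heartbeat(ticks: list[int], interval: float) -> None:
    """Hypothetical stand-in for Optimizer.log: records a tick every interval."""
    while True:
        await asyncio.sleep(interval)
        ticks.append(len(ticks))


async def run_with_logger(work_items: int, interval: float = 0.01) -> tuple[list[int], bool]:
    ticks: list[int] = []
    logger = asyncio.create_task(heartbeat(ticks, interval))
    # Main work: a few short coroutines awaited together
    await asyncio.gather(*(asyncio.sleep(0.05) for _ in range(work_items)))
    logger.cancel()
    try:
        await logger  # wait until the cancellation has actually been processed
    except asyncio.CancelledError:
        pass
    return ticks, logger.cancelled()


if __name__ == "__main__":
    ticks, cancelled = asyncio.run(run_with_logger(3))
    print(f"logger ticked {len(ticks)} times; cancelled={cancelled}")
```

Skipping the `await logger` step would leave the task pending when the event loop shuts down, which typically produces a "Task was destroyed but it is pending" warning.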
```
    def __len__(self) -> int:
        """Return the number of trials in history."""
        return len(self.study.trials)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
This method simply allows us to query the number of trials already associated with the study.
## Define the objective function
The `objective` task defines how we evaluate a particular set of hyperparameters. Because it is a Flyte task, each invocation is cached, tracked, and recoverable across executions; because it is `async`, the optimizer can await many trials concurrently.
```
@driver.task
async def objective(params: dict[str, Union[int, float]]) -> float:
    data = load_iris()
    X, y = shuffle(data.data, data.target, random_state=42)
    clf = RandomForestClassifier(
        n_estimators=params["n_estimators"],
        max_depth=params["max_depth"],
        min_samples_split=params["min_samples_split"],
        random_state=42,
        n_jobs=-1,
    )
    # Use cross-validation to evaluate performance
    score = cross_val_score(clf, X, y, cv=3, scoring="accuracy").mean()
    return score.item()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
We use the Iris dataset as a toy classification problem. The input `params` dictionary contains the trial's hyperparameters, which we unpack into a `RandomForestClassifier`. We shuffle the dataset with a fixed seed, so every trial sees the same ordering, and compute 3-fold cross-validation accuracy.
## Define the main optimization loop
The optimize task is the main driver of our optimization experiment. It creates the `Optimizer` instance and invokes it.
```
@driver.task
async def optimize(
    n_trials: int = 20,
    concurrency: int = 5,
    delay: float = 0.05,
    log_delay: float = 0.1,
) -> dict[str, Union[int, float]]:
    optimizer = Optimizer(
        objective=objective,
        n_trials=n_trials,
        concurrency=concurrency,
        delay=delay,
        log_delay=log_delay,
        study=optuna.create_study(
            direction="maximize", sampler=optuna.samplers.TPESampler(seed=42)
        ),
    )
    await optimizer()
    best = optimizer.study.best_trial
    print("Best Trial")
    print(" Number :", best.number)
    print(" Params :", best.params)
    print(" Score :", best.value)
    return best.params
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
We configure a `TPESampler` for Optuna and `seed` it for determinism. After all trials finish, we extract the best-performing trial and print its parameters and score. Returning the best params lets downstream tasks or clients train a model with the tuned configuration.
## Run the experiment
Finally, we include an executable entry point to run this optimization using `flyte.run`.
```
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(optimize, 100, 10)
    print(run.url)
    run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/ml/optimizer.py)
We load Flyte config from `config.yaml`, launch the optimize task with 100 trials and concurrency of 10, and print a link to view the execution in the Flyte UI.

_Each objective run is cached, recorded, and recoverable. With concurrency set to 10, only 10 trials execute in parallel at any given time._
=== PAGE: https://www.union.ai/docs/v2/byoc/tutorials/trading-agents ===
# Multi-agent trading simulation
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/trading_agents); based on work by [TauricResearch](https://github.com/TauricResearch/TradingAgents).
This example walks you through building a multi-agent trading simulation, modeling how agents within a firm might interact, strategize, and make trades collaboratively.

_Trading agents execution visualization_
## TL;DR
- You'll build a trading firm made up of agents that analyze, argue, and act, modeled with Python functions.
- You'll use the Flyte SDK to orchestrate this world, giving you visibility, retries, caching, and durability.
- You'll learn how to plug in tools, structure conversations, and track decisions across agents.
- You'll see how agents debate, use context, generate reports, and retain memory via vector DBs.
## What is an agent, anyway?
Agentic workflows are a rising pattern for complex problem-solving with LLMs. Think of agents as:
- An LLM (like GPT-4 or Mistral)
- A loop that keeps them thinking until a goal is met
- A set of optional tools they can call (APIs, search, calculators, etc.)
- Enough tokens to reason about the problem at hand
That's it.
You define tools, bind them to an agent, and let it run, reasoning step-by-step, optionally using those tools, until it finishes.
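The loop described above fits in a few lines of plain Python. In this minimal sketch, `fake_llm` is a hypothetical scripted stand-in for a real model (it requests one tool call, then answers once it has seen the tool result), and `TOOLS` is a hypothetical one-tool registry. A real agent would call an LLM API in place of `fake_llm`.

```python
from typing import Callable

# Tools are plain functions the agent may call by name (hypothetical example tool).
TOOLS: dict[str, Callable[..., str]] = {
    "add": lambda a, b: str(a + b),
}


def fake_llm(messages: list[dict]) -> dict:
    """Scripted stand-in for an LLM: ask for a tool once, then answer."""
    tool_results = [m for m in messages if m["role"] == "tool"]
    if not tool_results:
        return {"content": "", "tool_calls": [{"name": "add", "args": {"a": 2, "b": 3}}]}
    return {"content": f"The answer is {tool_results[-1]['content']}.", "tool_calls": []}


def run_agent(question: str, max_iterations: int = 5) -> str:
    messages = [{"role": "user", "content": question}]
    for _ in range(max_iterations):
        reply = fake_llm(messages)
        messages.append({"role": "assistant", "content": reply["content"]})
        if not reply["tool_calls"]:
            return reply["content"]  # goal met: no more tools requested
        for call in reply["tool_calls"]:
            result = TOOLS[call["name"]](**call["args"])
            messages.append({"role": "tool", "content": result})
    return "Stopped: iteration limit reached."


print(run_agent("What is 2 + 3?"))  # The answer is 5.
```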
## What's different here?
We're not building yet another agent framework. You're free to use LangChain, custom code, or whatever setup you like.
What we're giving you is the missing piece: a way to run these workflows **reliably, observably, and at scale, with zero rewrites.**
With Flyte, you get:
- Prompt + tool traceability and full state retention
- Built-in retries, caching, and failure recovery
- A native way to plug in your agents; no magic syntax required
## How it works: step-by-step walkthrough
This simulation is powered by a Flyte task that orchestrates multiple intelligent agents working together to analyze a company's stock and make informed trading decisions.

_Trading agents schema_
### Entry point
Everything begins with a top-level Flyte task called `main`, which serves as the entry point to the workflow.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "akshare==1.16.98",
# "backtrader==1.9.78.123",
# "boto3==1.39.9",
# "chainlit==2.5.5",
# "eodhd==1.0.32",
# "feedparser==6.0.11",
# "finnhub-python==2.4.23",
# "langchain-experimental==0.3.4",
# "langchain-openai==0.3.23",
# "pandas==2.3.0",
# "parsel==1.10.0",
# "praw==7.8.1",
# "pytz==2025.2",
# "questionary==2.1.0",
# "redis==6.2.0",
# "requests==2.32.4",
# "stockstats==0.6.5",
# "tqdm==4.67.1",
# "tushare==1.4.21",
# "typing-extensions==4.14.0",
# "yfinance==0.2.63",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from copy import deepcopy
import agents
import agents.analysts
from agents.managers import create_research_manager, create_risk_manager
from agents.researchers import create_bear_researcher, create_bull_researcher
from agents.risk_debators import (
    create_neutral_debator,
    create_risky_debator,
    create_safe_debator,
)
from agents.trader import create_trader
from agents.utils.utils import AgentState
from flyte_env import DEEP_THINKING_LLM, QUICK_THINKING_LLM, env, flyte
from langchain_openai import ChatOpenAI
from reflection import (
    reflect_bear_researcher,
    reflect_bull_researcher,
    reflect_research_manager,
    reflect_risk_manager,
    reflect_trader,
)
@env.task
async def process_signal(full_signal: str, QUICK_THINKING_LLM: str) -> str:
    """Process a full trading signal to extract the core decision."""
    messages = [
        {
            "role": "system",
            "content": """You are an efficient assistant designed to analyze paragraphs or
financial reports provided by a group of analysts.
Your task is to extract the investment decision: SELL, BUY, or HOLD.
Provide only the extracted decision (SELL, BUY, or HOLD) as your output,
without adding any additional text or information.""",
        },
        {"role": "human", "content": full_signal},
    ]
    return ChatOpenAI(model=QUICK_THINKING_LLM).invoke(messages).content
async def run_analyst(analyst_name, state, online_tools):
    # Look up the analyst task by name, e.g. create_market_analyst.
    # The caller passes in its own deepcopy of the state for isolation.
    run_fn = getattr(agents.analysts, f"create_{analyst_name}_analyst")
    # Run the analyst's chain
    result_state = await run_fn(QUICK_THINKING_LLM, state, online_tools)
    # Determine the report key
    report_key = (
        "sentiment_report"
        if analyst_name == "social_media"
        else f"{analyst_name}_report"
    )
    report_value = getattr(result_state, report_key)
    # Skip the initial prompt message, which the shared state already contains
    return result_state.messages[1:], report_key, report_value
@env.task
async def main(
    selected_analysts: list[str] = [
        "market",
        "fundamentals",
        "news",
        "social_media",
    ],
    max_debate_rounds: int = 1,
    max_risk_discuss_rounds: int = 1,
    online_tools: bool = True,
    company_name: str = "NVDA",
    trade_date: str = "2024-05-12",
) -> tuple[str, AgentState]:
    if not selected_analysts:
        raise ValueError(
            "No analysts selected. Please select at least one analyst from market, fundamentals, news, or social_media."
        )
    state = AgentState(
        messages=[{"role": "human", "content": company_name}],
        company_of_interest=company_name,
        trade_date=str(trade_date),
    )
    # Run all analysts concurrently
    results = await asyncio.gather(
        *[
            run_analyst(analyst, deepcopy(state), online_tools)
            for analyst in selected_analysts
        ]
    )
    # Flatten and append all resulting messages into the shared state
    for messages, report_attr, report in results:
        state.messages.extend(messages)
        setattr(state, report_attr, report)
    # Bull/Bear debate loop
    state = await create_bull_researcher(QUICK_THINKING_LLM, state)  # Start with bull
    while state.investment_debate_state.count < 2 * max_debate_rounds:
        current = state.investment_debate_state.current_response
        if current.startswith("Bull"):
            state = await create_bear_researcher(QUICK_THINKING_LLM, state)
        else:
            state = await create_bull_researcher(QUICK_THINKING_LLM, state)
    state = await create_research_manager(DEEP_THINKING_LLM, state)
    state = await create_trader(QUICK_THINKING_LLM, state)
    # Risk debate loop
    state = await create_risky_debator(QUICK_THINKING_LLM, state)  # Start with risky
    while state.risk_debate_state.count < 3 * max_risk_discuss_rounds:
        speaker = state.risk_debate_state.latest_speaker
        if speaker == "Risky":
            state = await create_safe_debator(QUICK_THINKING_LLM, state)
        elif speaker == "Safe":
            state = await create_neutral_debator(QUICK_THINKING_LLM, state)
        else:
            state = await create_risky_debator(QUICK_THINKING_LLM, state)
    state = await create_risk_manager(DEEP_THINKING_LLM, state)
    decision = await process_signal(state.final_trade_decision, QUICK_THINKING_LLM)
    return decision, state
@env.task
async def reflect_and_store(state: AgentState, returns: str) -> str:
    await asyncio.gather(
        reflect_bear_researcher(state, returns),
        reflect_bull_researcher(state, returns),
        reflect_trader(state, returns),
        reflect_risk_manager(state, returns),
        reflect_research_manager(state, returns),
    )
    return "Reflection completed."

# Re-run `main` with the same inputs, then reflect on its decisions
@env.task(cache="disable")
async def reflect_on_decisions(
    returns: str,
    selected_analysts: list[str] = [
        "market",
        "fundamentals",
        "news",
        "social_media",
    ],
    max_debate_rounds: int = 1,
    max_risk_discuss_rounds: int = 1,
    online_tools: bool = True,
    company_name: str = "NVDA",
    trade_date: str = "2024-05-12",
) -> str:
    _, state = await main(
        selected_analysts,
        max_debate_rounds,
        max_risk_discuss_rounds,
        online_tools,
        company_name,
        trade_date,
    )
    return await reflect_and_store(state, returns)
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
    # run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
    # print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py)
This task accepts several inputs:
- the list of analysts to run,
- the number of debate and risk discussion rounds,
- a flag to enable online tools,
- the company you're evaluating,
- and the target trading date.
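To see how the two round counts translate into agent turns, the sketch below replays the two loops from `main` without any LLM calls. It assumes, as the loop conditions suggest, that each researcher or debator call advances the shared `count` by one; that bookkeeping lives inside the agent implementations, which are not shown here. `debate_speakers` and `risk_speakers` are hypothetical helper names.

```python
def debate_speakers(max_debate_rounds: int) -> list[str]:
    """Replay the bull/bear loop: bull opens, then speakers alternate."""
    speakers = ["Bull"]  # main() always starts with the bull researcher
    while len(speakers) < 2 * max_debate_rounds:
        speakers.append("Bear" if speakers[-1] == "Bull" else "Bull")
    return speakers


def risk_speakers(max_risk_discuss_rounds: int) -> list[str]:
    """Replay the risk loop: Risky -> Safe -> Neutral -> Risky -> ..."""
    next_speaker = {"Risky": "Safe", "Safe": "Neutral", "Neutral": "Risky"}
    speakers = ["Risky"]  # main() always starts with the risky debator
    while len(speakers) < 3 * max_risk_discuss_rounds:
        speakers.append(next_speaker[speakers[-1]])
    return speakers


print(debate_speakers(1))  # ['Bull', 'Bear']
print(risk_speakers(1))    # ['Risky', 'Safe', 'Neutral']
print(len(debate_speakers(3)), len(risk_speakers(2)))  # 6 6
```

So one debate round means one bull and one bear turn, and one risk round means one turn for each of the three debators.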
The most interesting parameter here is the list of analysts to run. It determines which analyst agents will be invoked and shapes the overall structure of the simulation. Based on this input, the task dynamically launches agent tasks, running them in parallel.
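The fan-out in `main` is ordinary `asyncio`: each analyst receives its own `deepcopy` of the state, so concurrent analysts cannot mutate each other's view, and their new messages and reports are merged back afterward. A stdlib-only sketch of that pattern, where `State` and `fake_analyst` are hypothetical stand-ins for `AgentState` and the analyst tasks (reports are kept in a dict here rather than set with `setattr`):

```python
import asyncio
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class State:  # hypothetical stand-in for AgentState
    messages: list[dict] = field(default_factory=list)
    reports: dict[str, str] = field(default_factory=dict)


async def fake_analyst(name: str, state: State) -> tuple[list[dict], str, str]:
    await asyncio.sleep(0.01)  # simulated LLM/tool latency
    state.messages.append({"role": "assistant", "content": f"{name} notes"})
    # Return only the new messages, mirroring result_state.messages[1:] in run_analyst
    return state.messages[1:], f"{name}_report", f"{name} report text"


async def fan_out(names: list[str]) -> State:
    state = State(messages=[{"role": "human", "content": "NVDA"}])
    results = await asyncio.gather(
        *[fake_analyst(n, deepcopy(state)) for n in names]  # isolated copies
    )
    for new_messages, key, report in results:  # gather preserves input order
        state.messages.extend(new_messages)
        state.reports[key] = report
    return state


final = asyncio.run(fan_out(["market", "news"]))
print(len(final.messages))    # 3: the prompt plus one message per analyst
print(sorted(final.reports))  # ['market_report', 'news_report']
```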
The `main` task is written as a regular asynchronous Python function wrapped with Flyte's task decorator. No domain-specific language or orchestration glue is needed: just idiomatic Python, optionally using async for better performance. The task environment is configured once and shared across all tasks for consistency.
```
import flyte

QUICK_THINKING_LLM = "gpt-4o-mini"
DEEP_THINKING_LLM = "o4-mini"

env = flyte.TaskEnvironment(
    name="trading-agents",
    secrets=[
        flyte.Secret(key="finnhub_api_key", as_env_var="FINNHUB_API_KEY"),
        flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY"),
    ],
    image=flyte.Image.from_uv_script("main.py", name="trading-agents", pre=True),
    resources=flyte.Resources(cpu="1"),
    cache="auto",
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/flyte_env.py)
### Analyst agents
Each analyst agent comes equipped with a set of tools and a carefully designed prompt tailored to its specific domain. These tools are modular Flyte tasks β for example, downloading financial reports or computing technical indicators β and benefit from Flyte's built-in caching to avoid redundant computation.
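Flyte's task caching keys a stored result on the task (and its version) together with its inputs. As a rough stdlib analogy only (Flyte persists cache entries across runs, which `functools.lru_cache` does not), the sketch below shows why `get_stockstats_indicators_report` can be asked for several indicators without re-downloading data: every call on a given day derives the same 15-year window from today's date, so `cache_market_data` receives identical inputs. `download_prices` and `indicator_report` are hypothetical names, and the window here is approximated with `timedelta`.

```python
import datetime
from functools import lru_cache

DOWNLOADS = {"count": 0}


@lru_cache(maxsize=None)
def download_prices(symbol: str, start_date: str, end_date: str) -> str:
    """Hypothetical stand-in for cache_market_data: pretend to fetch a CSV."""
    DOWNLOADS["count"] += 1
    return f"{symbol}-YFin-data-{start_date}-{end_date}.csv"


def indicator_report(symbol: str, indicator: str) -> str:
    today = datetime.date.today()
    start = today - datetime.timedelta(days=15 * 365)  # roughly 15 years back
    # Same (symbol, start, end) on every call today, so the cache is hit
    data_file = download_prices(symbol, start.isoformat(), today.isoformat())
    return f"{indicator} computed from {data_file}"


for ind in ["rsi", "macd", "boll"]:
    indicator_report("NVDA", ind)
print(DOWNLOADS["count"])  # 1: the inputs matched, so the "download" ran once
```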
```
from datetime import datetime
import pandas as pd
import tools.interface as interface
import yfinance as yf
from flyte_env import env
from flyte.io import File
@env.task
async def get_reddit_news(
    curr_date: str,  # Date you want to get news for in yyyy-mm-dd format
) -> str:
    """
    Retrieve global news from Reddit within a specified time frame.
    Args:
        curr_date (str): Date you want to get news for in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing the latest global news
        from Reddit in the specified time frame.
    """
    global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
    return global_news_result
@env.task
async def get_finnhub_news(
    ticker: str,  # Ticker of a company, e.g. AAPL, TSM
    start_date: str,  # Start date in yyyy-mm-dd format
    end_date: str,  # End date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest news about a given stock from Finnhub within a date range
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        start_date (str): Start date in yyyy-mm-dd format
        end_date (str): End date in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing news about the company
        within the date range from start_date to end_date
    """
    # Convert the date range into the look-back window the interface expects
    end_dt = datetime.strptime(end_date, "%Y-%m-%d")
    start_dt = datetime.strptime(start_date, "%Y-%m-%d")
    look_back_days = (end_dt - start_dt).days
    finnhub_news_result = interface.get_finnhub_news(
        ticker, end_date, look_back_days
    )
    return finnhub_news_result
@env.task
async def get_reddit_stock_info(
    ticker: str,  # Ticker of a company. e.g. AAPL, TSM
    curr_date: str,  # Current date you want to get news for
) -> str:
    """
    Retrieve the latest news about a given stock from Reddit, given the current date.
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        curr_date (str): current date in yyyy-mm-dd format to get news for
    Returns:
        str: A formatted dataframe containing the latest news about the company on the given date
    """
    stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)
    return stock_news_results
@env.task
async def get_YFin_data(
    symbol: str,  # ticker symbol of the company
    start_date: str,  # Start date in yyyy-mm-dd format
    end_date: str,  # End date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        start_date (str): Start date in yyyy-mm-dd format
        end_date (str): End date in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing the stock price data
        for the specified ticker symbol in the specified date range.
    """
    result_data = interface.get_YFin_data(symbol, start_date, end_date)
    return result_data
@env.task
async def get_YFin_data_online(
    symbol: str,  # ticker symbol of the company
    start_date: str,  # Start date in yyyy-mm-dd format
    end_date: str,  # End date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        start_date (str): Start date in yyyy-mm-dd format
        end_date (str): End date in yyyy-mm-dd format
    Returns:
        str: A formatted dataframe containing the stock price data
        for the specified ticker symbol in the specified date range.
    """
    result_data = interface.get_YFin_data_online(symbol, start_date, end_date)
    return result_data
@env.task
async def cache_market_data(symbol: str, start_date: str, end_date: str) -> File:
    data_file = f"{symbol}-YFin-data-{start_date}-{end_date}.csv"
    data = yf.download(
        symbol,
        start=start_date,
        end=end_date,
        multi_level_index=False,
        progress=False,
        auto_adjust=True,
    )
    data = data.reset_index()
    data.to_csv(data_file, index=False)
    return await File.from_local(data_file)
@env.task
async def get_stockstats_indicators_report(
    symbol: str,  # ticker symbol of the company
    indicator: str,  # technical indicator to get the analysis and report of
    curr_date: str,  # The current trading date you are trading on, YYYY-mm-dd
    look_back_days: int = 30,  # how many days to look back
) -> str:
    """
    Retrieve stock stats indicators for a given ticker symbol and indicator.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        indicator (str): Technical indicator to get the analysis and report of
        curr_date (str): The current trading date you are trading on, YYYY-mm-dd
        look_back_days (int): How many days to look back, default is 30
    Returns:
        str: A formatted dataframe containing the stock stats indicators
        for the specified ticker symbol and indicator.
    """
    today_date = pd.Timestamp.today()
    end_date = today_date
    start_date = today_date - pd.DateOffset(years=15)
    start_date = start_date.strftime("%Y-%m-%d")
    end_date = end_date.strftime("%Y-%m-%d")
    data_file = await cache_market_data(symbol, start_date, end_date)
    local_data_file = await data_file.download()
    result_stockstats = interface.get_stock_stats_indicators_window(
        symbol, indicator, curr_date, look_back_days, False, local_data_file
    )
    return result_stockstats
@env.task
async def get_stockstats_indicators_report_online(
    symbol: str,  # ticker symbol of the company
    indicator: str,  # technical indicator to get the analysis and report of
    curr_date: str,  # The current trading date you are trading on, YYYY-mm-dd
    look_back_days: int = 30,  # how many days to look back
) -> str:
    """
    Retrieve stock stats indicators for a given ticker symbol and indicator.
    Args:
        symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
        indicator (str): Technical indicator to get the analysis and report of
        curr_date (str): The current trading date you are trading on, YYYY-mm-dd
        look_back_days (int): How many days to look back, default is 30
    Returns:
        str: A formatted dataframe containing the stock stats indicators
        for the specified ticker symbol and indicator.
    """
    today_date = pd.Timestamp.today()
    end_date = today_date
    start_date = today_date - pd.DateOffset(years=15)
    start_date = start_date.strftime("%Y-%m-%d")
    end_date = end_date.strftime("%Y-%m-%d")
    data_file = await cache_market_data(symbol, start_date, end_date)
    local_data_file = await data_file.download()
    result_stockstats = interface.get_stock_stats_indicators_window(
        symbol, indicator, curr_date, look_back_days, True, local_data_file
    )
    return result_stockstats
@env.task
async def get_finnhub_company_insider_sentiment(
    ticker: str,  # ticker symbol for the company
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve insider sentiment information about a company (retrieved
    from public SEC information) for the past 30 days
    Args:
        ticker (str): ticker symbol of the company
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the sentiment in the past 30 days starting at curr_date
    """
    data_sentiment = interface.get_finnhub_company_insider_sentiment(
        ticker, curr_date, 30
    )
    return data_sentiment
@env.task
async def get_finnhub_company_insider_transactions(
    ticker: str,  # ticker symbol
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve insider transaction information about a company
    (retrieved from public SEC information) for the past 30 days
    Args:
        ticker (str): ticker symbol of the company
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's insider transactions/trading information in the past 30 days
    """
    data_trans = interface.get_finnhub_company_insider_transactions(
        ticker, curr_date, 30
    )
    return data_trans
@env.task
async def get_simfin_balance_sheet(
    ticker: str,  # ticker symbol
    freq: str,  # reporting frequency of the company's financial history: annual/quarterly
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve the most recent balance sheet of a company
    Args:
        ticker (str): ticker symbol of the company
        freq (str): reporting frequency of the company's financial history: annual / quarterly
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's most recent balance sheet
    """
    data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date)
    return data_balance_sheet
@env.task
async def get_simfin_cashflow(
    ticker: str,  # ticker symbol
    freq: str,  # reporting frequency of the company's financial history: annual/quarterly
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve the most recent cash flow statement of a company
    Args:
        ticker (str): ticker symbol of the company
        freq (str): reporting frequency of the company's financial history: annual / quarterly
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's most recent cash flow statement
    """
    data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date)
    return data_cashflow
@env.task
async def get_simfin_income_stmt(
    ticker: str,  # ticker symbol
    freq: str,  # reporting frequency of the company's financial history: annual/quarterly
    curr_date: str,  # current date you are trading at, yyyy-mm-dd
) -> str:
    """
    Retrieve the most recent income statement of a company
    Args:
        ticker (str): ticker symbol of the company
        freq (str): reporting frequency of the company's financial history: annual / quarterly
        curr_date (str): current date you are trading at, yyyy-mm-dd
    Returns:
        str: a report of the company's most recent income statement
    """
    data_income_stmt = interface.get_simfin_income_statements(ticker, freq, curr_date)
    return data_income_stmt
@env.task
async def get_google_news(
    query: str,  # Query to search with
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest news from Google News for a query, looking back 7 days from curr_date.
    Args:
        query (str): Query to search with
        curr_date (str): Current date in yyyy-mm-dd format
    Returns:
        str: A formatted string containing the latest news from Google News
        for the query over the 7 days up to curr_date.
    """
    google_news_results = interface.get_google_news(query, curr_date, 7)
    return google_news_results
@env.task
async def get_stock_news_openai(
    ticker: str,  # the company's ticker
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest news about a given stock by using OpenAI's news API.
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        curr_date (str): Current date in yyyy-mm-dd format
    Returns:
        str: A formatted string containing the latest news about the company on the given date.
    """
    openai_news_results = interface.get_stock_news_openai(ticker, curr_date)
    return openai_news_results
@env.task
async def get_global_news_openai(
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest macroeconomics news on a given date using OpenAI's macroeconomics news API.
    Args:
        curr_date (str): Current date in yyyy-mm-dd format
    Returns:
        str: A formatted string containing the latest macroeconomic news on the given date.
    """
    openai_news_results = interface.get_global_news_openai(curr_date)
    return openai_news_results
@env.task
async def get_fundamentals_openai(
    ticker: str,  # the company's ticker
    curr_date: str,  # Current date in yyyy-mm-dd format
) -> str:
    """
    Retrieve the latest fundamental information about a given stock
    on a given date by using OpenAI's news API.
    Args:
        ticker (str): Ticker of a company. e.g. AAPL, TSM
        curr_date (str): Current date in yyyy-mm-dd format
    Returns:
        str: A formatted string containing the latest fundamental information
        about the company on the given date.
    """
    openai_fundamentals_results = interface.get_fundamentals_openai(ticker, curr_date)
    return openai_fundamentals_results
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/tools/toolkit.py)
When initialized, an analyst enters a structured reasoning loop (via LangChain), where it can call tools, observe outputs, and refine its internal state before generating a final report. These reports are later consumed by downstream agents.
Here's an example of a news analyst that interprets global events and macroeconomic signals. We specify the tools accessible to the analyst, and the LLM selects which ones to use based on context.
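The dispatch step inside `run_chain_with_tools` maps each tool call the model emits back to a function by name, runs the calls concurrently, and drops calls to tools that were never bound. A stdlib-only sketch of that step, where the `toolkit` namespace, the two async tool functions, and the call dictionaries are hypothetical stand-ins for the Flyte tasks and LangChain tool-call objects:

```python
import asyncio
from types import SimpleNamespace
from typing import Optional


async def get_google_news(query: str, curr_date: str) -> str:
    await asyncio.sleep(0.01)  # simulated network call
    return f"news for {query} on {curr_date}"


async def get_global_news_openai(curr_date: str) -> str:
    await asyncio.sleep(0.01)  # simulated network call
    return f"macro news on {curr_date}"


# Hypothetical stand-in for the tools.toolkit module
toolkit = SimpleNamespace(
    get_google_news=get_google_news,
    get_global_news_openai=get_global_news_openai,
)


async def run_single_tool(tool_call: dict) -> Optional[dict]:
    tool = getattr(toolkit, tool_call["name"], None)
    if tool is None:  # the model asked for a tool that does not exist
        return None
    content = await tool(**tool_call["args"])
    return {"role": "tool", "tool_call_id": tool_call["id"], "content": content}


async def dispatch(tool_calls: list[dict]) -> list[dict]:
    results = await asyncio.gather(*[run_single_tool(tc) for tc in tool_calls])
    return [r for r in results if r]  # keep only calls that resolved to a tool


calls = [
    {"id": "1", "name": "get_google_news", "args": {"query": "NVDA", "curr_date": "2024-05-12"}},
    {"id": "2", "name": "get_global_news_openai", "args": {"curr_date": "2024-05-12"}},
    {"id": "3", "name": "not_a_tool", "args": {}},
]
for msg in asyncio.run(dispatch(calls)):
    print(msg["tool_call_id"], msg["content"])
```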
```
import asyncio
from agents.utils.utils import AgentState
from flyte_env import env
from langchain_core.messages import ToolMessage, convert_to_openai_messages
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from tools import toolkit
import flyte
MAX_ITERATIONS = 5
async def run_chain_with_tools(
    type: str, state: AgentState, llm: str, system_message: str, tool_names: list[str]
) -> AgentState:
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK; another assistant with different tools"
                " will help where you left off. Execute what you can to make progress."
                " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
                " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
                " You have access to the following tools: {tool_names}.\n{system_message}"
                " For your reference, the current date is {current_date}. The company we want to look at is {ticker}.",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join(tool_names))
    prompt = prompt.partial(current_date=state.trade_date)
    prompt = prompt.partial(ticker=state.company_of_interest)
    chain = prompt | ChatOpenAI(model=llm).bind_tools(
        [getattr(toolkit, tool_name).func for tool_name in tool_names]
    )
    iteration = 0
    while iteration < MAX_ITERATIONS:
        result = await chain.ainvoke(state.messages)
        state.messages.append(convert_to_openai_messages(result))
        if not result.tool_calls:
            # Final response: no tools required
            setattr(state, f"{type}_report", result.content or "")
            break

        # Run all tool calls in parallel
        async def run_single_tool(tool_call):
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool = getattr(toolkit, tool_name, None)
            if not tool:
                return None
            content = await tool(**tool_args)
            return ToolMessage(
                tool_call_id=tool_call["id"], name=tool_name, content=content
            )

        with flyte.group(f"tool_calls_iteration_{iteration}"):
            tool_messages = await asyncio.gather(
                *[run_single_tool(tc) for tc in result.tool_calls]
            )
        # Add valid tool results to state
        tool_messages = [msg for msg in tool_messages if msg]
        state.messages.extend(convert_to_openai_messages(tool_messages))
        iteration += 1
    else:
        # Reached iteration cap: optionally raise or log
        print(f"Max iterations ({MAX_ITERATIONS}) reached for {type}")
    return state
@env.task
async def create_fundamentals_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [toolkit.get_fundamentals_openai]
    else:
        tools = [
            toolkit.get_finnhub_company_insider_sentiment,
            toolkit.get_finnhub_company_insider_transactions,
            toolkit.get_simfin_balance_sheet,
            toolkit.get_simfin_cashflow,
            toolkit.get_simfin_income_stmt,
        ]
    system_message = (
        "You are a researcher tasked with analyzing fundamental information over the past week about a company. "
        "Please write a comprehensive report of the company's fundamental information such as financial documents, "
        "company profile, basic company financials, company financial history, insider sentiment, and insider "
        "transactions to gain a full view of the company's "
        "fundamental information to inform traders. Make sure to include as much detail as possible. "
        "Do not simply state the trends are mixed, "
        "provide detailed and finegrained analysis and insights that may help traders make decisions. "
        "Make sure to append a Markdown table at the end of the report to organize key points in the report, "
        "organized and easy to read."
    )
    tool_names = [tool.func.__name__ for tool in tools]
    return await run_chain_with_tools(
        "fundamentals", state, llm, system_message, tool_names
    )
@env.task
async def create_market_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [
            toolkit.get_YFin_data_online,
            toolkit.get_stockstats_indicators_report_online,
        ]
    else:
        tools = [
            toolkit.get_YFin_data,
            toolkit.get_stockstats_indicators_report,
        ]
    system_message = (
        """You are a trading assistant tasked with analyzing financial markets.
Your role is to select the **most relevant indicators** for a given market condition
or trading strategy from the following list.
The goal is to choose up to **8 indicators** that provide complementary insights without redundancy.
Categories and each category's indicators are:
Moving Averages:
- close_50_sma: 50 SMA: A medium-term trend indicator.
Usage: Identify trend direction and serve as dynamic support/resistance.
Tips: It lags price; combine with faster indicators for timely signals.
- close_200_sma: 200 SMA: A long-term trend benchmark.
Usage: Confirm overall market trend and identify golden/death cross setups.
Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.
- close_10_ema: 10 EMA: A responsive short-term average.
Usage: Capture quick shifts in momentum and potential entry points.
Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.
MACD Related:
- macd: MACD: Computes momentum via differences of EMAs.
Usage: Look for crossovers and divergence as signals of trend changes.
Tips: Confirm with other indicators in low-volatility or sideways markets.
- macds: MACD Signal: An EMA smoothing of the MACD line.
Usage: Use crossovers with the MACD line to trigger trades.
Tips: Should be part of a broader strategy to avoid false positives.
- macdh: MACD Histogram: Shows the gap between the MACD line and its signal.
Usage: Visualize momentum strength and spot divergence early.
Tips: Can be volatile; complement with additional filters in fast-moving markets.
Momentum Indicators:
- rsi: RSI: Measures momentum to flag overbought/oversold conditions.
Usage: Apply 70/30 thresholds and watch for divergence to signal reversals.
Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.
Volatility Indicators:
- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands.
Usage: Acts as a dynamic benchmark for price movement.
Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.
- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line.
Usage: Signals potential overbought conditions and breakout zones.
Tips: Confirm signals with other tools; prices may ride the band in strong trends.
- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line.
Usage: Indicates potential oversold conditions.
Tips: Use additional analysis to avoid false reversal signals.
- atr: ATR: Averages true range to measure volatility.
Usage: Set stop-loss levels and adjust position sizes based on current market volatility.
Tips: It's a reactive measure, so use it as part of a broader risk management strategy.
Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume.
Usage: Confirm trends by integrating price action with volume data.
Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information.
Avoid redundancy (e.g., do not select both rsi and stochrsi).
Also briefly explain why they are suitable for the given market context.
When you tool call, please use the exact name of the indicators provided above as they are defined parameters,
otherwise your call will fail.
Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators.
Write a very detailed and nuanced report of the trends you observe.
Do not simply state the trends are mixed, provide detailed and finegrained analysis
and insights that may help traders make decisions."""
""" Make sure to append a Markdown table at the end of the report to
organize key points in the report, organized and easy to read."""
)
tool_names = [tool.func.__name__ for tool in tools]
return await run_chain_with_tools("market", state, llm, system_message, tool_names)
@env.task
async def create_news_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [
            toolkit.get_global_news_openai,
            toolkit.get_google_news,
        ]
    else:
        tools = [
            toolkit.get_finnhub_news,
            toolkit.get_reddit_news,
            toolkit.get_google_news,
        ]
    system_message = (
        "You are a news researcher tasked with analyzing recent news and trends over the past week. "
        "Please write a comprehensive report of the current state of the world that is relevant for "
        "trading and macroeconomics. "
        "Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, "
        "provide detailed and finegrained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
organized and easy to read."""
    )
    tool_names = [tool.func.__name__ for tool in tools]
    return await run_chain_with_tools("news", state, llm, system_message, tool_names)
@env.task
async def create_social_media_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [toolkit.get_stock_news_openai]
    else:
        tools = [toolkit.get_reddit_stock_info]
    system_message = (
        "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, "
        "recent company news, and public sentiment for a specific company over the past week. "
        "You will be given a company's name; your objective is to write a comprehensive long report "
        "detailing your analysis, insights, and implications for traders and investors on this company's current state "
        "after looking at social media and what people are saying about that company, "
        "analyzing sentiment data of what people feel each day about the company, and looking at recent company news. "
        "Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends "
        "are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
organized and easy to read."""
    )
    tool_names = [tool.func.__name__ for tool in tools]
    return await run_chain_with_tools(
        "sentiment", state, llm, system_message, tool_names
    )
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/analysts.py)
Each analyst agent uses a helper function to bind tools, iterate through reasoning steps (up to a configurable maximum), and produce an answer. Setting a max iteration count is crucial to prevent runaway loops. As agents reason, their message history is preserved in their internal state and passed along to the next agent in the chain.
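The iteration cap in `run_chain_with_tools` relies on Python's `while ... else`: the `else` branch runs only when the loop condition becomes false, never after a `break`. So the "max iterations reached" message appears only when the model requested tools on every pass. The sketch below isolates that control flow; `run_capped` is a hypothetical helper, and `replies` is a hypothetical script of whether each model reply contains tool calls.

```python
MAX_ITERATIONS = 5


def run_capped(replies: list[bool]) -> tuple[int, bool]:
    """Return (iterations used, whether the cap was hit).

    Each entry in `replies` says whether that model reply requested tools.
    """
    iteration = 0
    hit_cap = False
    while iteration < MAX_ITERATIONS:
        if not replies[iteration]:
            break  # final answer: the else branch below is skipped
        iteration += 1  # tools were called; go around again
    else:
        hit_cap = True  # loop condition went false without a break
    return iteration, hit_cap


print(run_capped([True, True, False]))  # (2, False)
print(run_capped([True] * 5))           # (5, True)
```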
```python
import asyncio
from agents.utils.utils import AgentState
from flyte_env import env
from langchain_core.messages import ToolMessage, convert_to_openai_messages
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from tools import toolkit
import flyte

MAX_ITERATIONS = 5

# {{docs-fragment agent_helper}}
async def run_chain_with_tools(
    type: str, state: AgentState, llm: str, system_message: str, tool_names: list[str]
) -> AgentState:
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK; another assistant with different tools"
                " will help where you left off. Execute what you can to make progress."
                " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
                " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
                " You have access to the following tools: {tool_names}.\n{system_message}"
                " For your reference, the current date is {current_date}. The company we want to look at is {ticker}.",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join(tool_names))
    prompt = prompt.partial(current_date=state.trade_date)
    prompt = prompt.partial(ticker=state.company_of_interest)

    chain = prompt | ChatOpenAI(model=llm).bind_tools(
        [getattr(toolkit, tool_name).func for tool_name in tool_names]
    )

    iteration = 0
    while iteration < MAX_ITERATIONS:
        result = await chain.ainvoke(state.messages)
        state.messages.append(convert_to_openai_messages(result))

        if not result.tool_calls:
            # Final response: no tools required
            setattr(state, f"{type}_report", result.content or "")
            break

        # Run all tool calls in parallel
        async def run_single_tool(tool_call):
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            tool = getattr(toolkit, tool_name, None)
            if not tool:
                return None
            content = await tool(**tool_args)
            return ToolMessage(
                tool_call_id=tool_call["id"], name=tool_name, content=content
            )

        with flyte.group(f"tool_calls_iteration_{iteration}"):
            tool_messages = await asyncio.gather(
                *[run_single_tool(tc) for tc in result.tool_calls]
            )

        # Add valid tool results to state
        tool_messages = [msg for msg in tool_messages if msg]
        state.messages.extend(convert_to_openai_messages(tool_messages))

        iteration += 1
    else:
        # Reached iteration cap: optionally raise or log
        print(f"Max iterations ({MAX_ITERATIONS}) reached for {type}")

    return state
# {{/docs-fragment agent_helper}}

@env.task
async def create_fundamentals_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [toolkit.get_fundamentals_openai]
    else:
        tools = [
            toolkit.get_finnhub_company_insider_sentiment,
            toolkit.get_finnhub_company_insider_transactions,
            toolkit.get_simfin_balance_sheet,
            toolkit.get_simfin_cashflow,
            toolkit.get_simfin_income_stmt,
        ]

    system_message = (
        "You are a researcher tasked with analyzing fundamental information over the past week about a company. "
        "Please write a comprehensive report of the company's fundamental information such as financial documents, "
        "company profile, basic company financials, company financial history, insider sentiment, and insider "
        "transactions to gain a full view of the company's "
        "fundamental information to inform traders. Make sure to include as much detail as possible. "
        "Do not simply state the trends are mixed, "
        "provide detailed and fine-grained analysis and insights that may help traders make decisions. "
        "Make sure to append a Markdown table at the end of the report to organize key points in the report, "
        "organized and easy to read."
    )
    tool_names = [tool.func.__name__ for tool in tools]
    return await run_chain_with_tools(
        "fundamentals", state, llm, system_message, tool_names
    )

@env.task
async def create_market_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [
            toolkit.get_YFin_data_online,
            toolkit.get_stockstats_indicators_report_online,
        ]
    else:
        tools = [
            toolkit.get_YFin_data,
            toolkit.get_stockstats_indicators_report,
        ]

    system_message = (
        """You are a trading assistant tasked with analyzing financial markets.
Your role is to select the **most relevant indicators** for a given market condition
or trading strategy from the following list.
The goal is to choose up to **8 indicators** that provide complementary insights without redundancy.
Categories and each category's indicators are:
Moving Averages:
- close_50_sma: 50 SMA: A medium-term trend indicator.
Usage: Identify trend direction and serve as dynamic support/resistance.
Tips: It lags price; combine with faster indicators for timely signals.
- close_200_sma: 200 SMA: A long-term trend benchmark.
Usage: Confirm overall market trend and identify golden/death cross setups.
Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.
- close_10_ema: 10 EMA: A responsive short-term average.
Usage: Capture quick shifts in momentum and potential entry points.
Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.
MACD Related:
- macd: MACD: Computes momentum via differences of EMAs.
Usage: Look for crossovers and divergence as signals of trend changes.
Tips: Confirm with other indicators in low-volatility or sideways markets.
- macds: MACD Signal: An EMA smoothing of the MACD line.
Usage: Use crossovers with the MACD line to trigger trades.
Tips: Should be part of a broader strategy to avoid false positives.
- macdh: MACD Histogram: Shows the gap between the MACD line and its signal.
Usage: Visualize momentum strength and spot divergence early.
Tips: Can be volatile; complement with additional filters in fast-moving markets.
Momentum Indicators:
- rsi: RSI: Measures momentum to flag overbought/oversold conditions.
Usage: Apply 70/30 thresholds and watch for divergence to signal reversals.
Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.
Volatility Indicators:
- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands.
Usage: Acts as a dynamic benchmark for price movement.
Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.
- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line.
Usage: Signals potential overbought conditions and breakout zones.
Tips: Confirm signals with other tools; prices may ride the band in strong trends.
- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line.
Usage: Indicates potential oversold conditions.
Tips: Use additional analysis to avoid false reversal signals.
- atr: ATR: Averages true range to measure volatility.
Usage: Set stop-loss levels and adjust position sizes based on current market volatility.
Tips: It's a reactive measure, so use it as part of a broader risk management strategy.
Volume-Based Indicators:
- vwma: VWMA: A moving average weighted by volume.
Usage: Confirm trends by integrating price action with volume data.
Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
- Select indicators that provide diverse and complementary information.
Avoid redundancy (e.g., do not select both rsi and stochrsi).
Also briefly explain why they are suitable for the given market context.
When you tool call, please use the exact name of the indicators provided above as they are defined parameters,
otherwise your call will fail.
Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators.
Write a very detailed and nuanced report of the trends you observe.
Do not simply state the trends are mixed, provide detailed and fine-grained analysis
and insights that may help traders make decisions."""
        """ Make sure to append a Markdown table at the end of the report to
organize key points in the report, organized and easy to read."""
    )
    tool_names = [tool.func.__name__ for tool in tools]
    return await run_chain_with_tools("market", state, llm, system_message, tool_names)

# {{docs-fragment news_analyst}}
@env.task
async def create_news_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [
            toolkit.get_global_news_openai,
            toolkit.get_google_news,
        ]
    else:
        tools = [
            toolkit.get_finnhub_news,
            toolkit.get_reddit_news,
            toolkit.get_google_news,
        ]

    system_message = (
        "You are a news researcher tasked with analyzing recent news and trends over the past week. "
        "Please write a comprehensive report of the current state of the world that is relevant for "
        "trading and macroeconomics. "
        "Look at news from EODHD and Finnhub to be comprehensive. Do not simply state the trends are mixed, "
        "provide detailed and fine-grained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
organized and easy to read."""
    )
    tool_names = [tool.func.__name__ for tool in tools]
    return await run_chain_with_tools("news", state, llm, system_message, tool_names)
# {{/docs-fragment news_analyst}}

@env.task
async def create_social_media_analyst(
    llm: str, state: AgentState, online_tools: bool
) -> AgentState:
    if online_tools:
        tools = [toolkit.get_stock_news_openai]
    else:
        tools = [toolkit.get_reddit_stock_info]

    system_message = (
        "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, "
        "recent company news, and public sentiment for a specific company over the past week. "
        "You will be given a company's name. Your objective is to write a comprehensive long report "
        "detailing your analysis, insights, and implications for traders and investors on this company's current state "
        "after looking at social media and what people are saying about that company, "
        "analyzing sentiment data of what people feel each day about the company, and looking at recent company news. "
        "Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends "
        "are mixed, provide detailed and fine-grained analysis and insights that may help traders make decisions."
        """ Make sure to append a Markdown table at the end of the report to organize key points in the report,
organized and easy to read."""
    )
    tool_names = [tool.func.__name__ for tool in tools]
    return await run_chain_with_tools(
        "sentiment", state, llm, system_message, tool_names
    )
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/analysts.py)
Once all analyst reports are complete, their outputs are collected and passed to the next stage of the workflow.
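One way to picture the collection step is a fan-out over the analysts followed by a merge of their `*_report` fields. The sketch below is hypothetical: `State` is a trimmed-down stand-in for `AgentState`, `analyst` replaces the real analyst tasks, and running the analysts concurrently on independent copies of the state is an assumption, not necessarily how the tutorial's workflow schedules them.

```python
import asyncio
import copy
from dataclasses import dataclass


@dataclass
class State:
    # Hypothetical, trimmed-down stand-in for the tutorial's AgentState.
    company_of_interest: str
    market_report: str = ""
    news_report: str = ""
    sentiment_report: str = ""
    fundamentals_report: str = ""


REPORT_KINDS = ["market", "news", "sentiment", "fundamentals"]


async def analyst(kind: str, state: State) -> State:
    # Placeholder for an analyst task: it only fills in its own `<kind>_report` field.
    await asyncio.sleep(0)
    setattr(state, f"{kind}_report", f"{kind} report for {state.company_of_interest}")
    return state


async def collect_reports(state: State) -> State:
    # Each analyst gets an independent copy, so concurrent runs never share mutable state.
    results = await asyncio.gather(
        *(analyst(kind, copy.deepcopy(state)) for kind in REPORT_KINDS)
    )
    # Merge each analyst's report back into the state handed to the next stage.
    for kind, result in zip(REPORT_KINDS, results):
        field_name = f"{kind}_report"
        setattr(state, field_name, getattr(result, field_name))
    return state
```

Copying before fanning out matters because the real analyst tasks mutate and return the state they receive; merging only the report fields keeps each analyst's output separate.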
### Research agents
The research phase consists of two agents: a bullish researcher and a bearish one. They evaluate the company from opposing viewpoints, drawing on the analysts' reports. Unlike analysts, they don't use tools. Their role is to interpret, critique, and develop positions based on the evidence.
```python
from agents.utils.utils import AgentState, InvestmentDebateState, memory_init
from flyte_env import env
from langchain_openai import ChatOpenAI

# {{docs-fragment bear_researcher}}
@env.task
async def create_bear_researcher(llm: str, state: AgentState) -> AgentState:
    investment_debate_state = state.investment_debate_state
    history = investment_debate_state.history
    bear_history = investment_debate_state.bear_history
    current_response = investment_debate_state.current_response
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="bear-researcher")
    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)
    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""You are a Bear Analyst making the case against investing in the stock.
Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators.
Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
Key points to focus on:
- Risks and Challenges: Highlight factors like market saturation, financial instability,
or macroeconomic threats that could hinder the stock's performance.
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation,
or threats from competitors.
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning,
exposing weaknesses or over-optimistic assumptions.
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points
and debating effectively rather than simply listing facts.
Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bull argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate
that demonstrates the risks and weaknesses of investing in the stock.
You must also address reflections and learn from lessons and mistakes you made in the past.
"""

    response = ChatOpenAI(model=llm).invoke(prompt)
    argument = f"Bear Analyst: {response.content}"

    new_investment_debate_state = InvestmentDebateState(
        history=history + "\n" + argument,
        bear_history=bear_history + "\n" + argument,
        bull_history=investment_debate_state.bull_history,
        current_response=argument,
        count=investment_debate_state.count + 1,
    )
    state.investment_debate_state = new_investment_debate_state
    return state
# {{/docs-fragment bear_researcher}}

@env.task
async def create_bull_researcher(llm: str, state: AgentState) -> AgentState:
    investment_debate_state = state.investment_debate_state
    history = investment_debate_state.history
    bull_history = investment_debate_state.bull_history
    current_response = investment_debate_state.current_response
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="bull-researcher")
    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)
    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""You are a Bull Analyst advocating for investing in the stock.
Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages,
and positive market indicators.
Leverage the provided research and data to address concerns and counter bearish arguments effectively.
Key points to focus on:
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing
concerns thoroughly and showing why the bull perspective holds stronger merit.
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points
and debating effectively rather than just listing data.
Resources available:
Market research report: {market_research_report}
Social media sentiment report: {sentiment_report}
Latest world affairs news: {news_report}
Company fundamentals report: {fundamentals_report}
Conversation history of the debate: {history}
Last bear argument: {current_response}
Reflections from similar situations and lessons learned: {past_memory_str}
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate
that demonstrates the strengths of the bull position.
You must also address reflections and learn from lessons and mistakes you made in the past.
"""

    response = ChatOpenAI(model=llm).invoke(prompt)
    argument = f"Bull Analyst: {response.content}"

    new_investment_debate_state = InvestmentDebateState(
        history=history + "\n" + argument,
        bull_history=bull_history + "\n" + argument,
        bear_history=investment_debate_state.bear_history,
        current_response=argument,
        count=investment_debate_state.count + 1,
    )
    state.investment_debate_state = new_investment_debate_state
    return state
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/researchers.py)
To aid reasoning, the agents also retrieve relevant "memories" from a vector database (through `memory_init` and `memory.get_memories`), giving them richer historical context. The number of debate rounds is configurable: after the configured rounds of back-and-forth between the bull and bear researchers, a research manager agent reviews their arguments and makes the final investment decision.
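The stopping rule for the debate can be expressed as a turn counter: each round is one bull turn plus one bear turn, so the debate ends after `2 * max_debate_rounds` turns. A minimal sketch, assuming a hypothetical `DebateState` with the same field names as `InvestmentDebateState`, a placeholder `speak` function instead of the LLM call, and routing that hands the turn to whichever side did not speak last:

```python
from dataclasses import dataclass


@dataclass
class DebateState:
    # Hypothetical stand-in with the same field names as InvestmentDebateState.
    history: str = ""
    bull_history: str = ""
    bear_history: str = ""
    current_response: str = ""
    count: int = 0


def speak(side: str, state: DebateState) -> DebateState:
    # Placeholder for the researcher's LLM call; mirrors the append-and-increment
    # pattern used by create_bull_researcher / create_bear_researcher.
    argument = f"{side} Analyst: turn {state.count + 1}"
    state.history += "\n" + argument
    side_field = f"{side.lower()}_history"
    setattr(state, side_field, getattr(state, side_field) + "\n" + argument)
    state.current_response = argument
    state.count += 1
    return state


def next_speaker(state: DebateState, max_debate_rounds: int) -> str:
    # One round = one bull turn + one bear turn.
    if state.count >= 2 * max_debate_rounds:
        return "Research Manager"
    # Hand the turn to whichever side did not speak last (bull opens).
    return "Bear" if state.current_response.startswith("Bull") else "Bull"


def run_debate(max_debate_rounds: int) -> DebateState:
    state = DebateState()
    while (speaker := next_speaker(state, max_debate_rounds)) != "Research Manager":
        state = speak(speaker, state)
    return state
```

With `max_debate_rounds=2`, the bull and bear each speak twice and the last response comes from the bear, after which the research manager takes over.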
```python
from agents.utils.utils import (
    AgentState,
    InvestmentDebateState,
    RiskDebateState,
    memory_init,
)
from flyte_env import env
from langchain_openai import ChatOpenAI

# {{docs-fragment research_manager}}
@env.task
async def create_research_manager(llm: str, state: AgentState) -> AgentState:
    history = state.investment_debate_state.history
    investment_debate_state = state.investment_debate_state
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="research-manager")
    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)
    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate
this round of debate and make a definitive decision:
align with the bear analyst, the bull analyst,
or choose Hold only if it is strongly justified based on the arguments presented.
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning.
Your recommendation (Buy, Sell, or Hold) must be clear and actionable.
Avoid defaulting to Hold simply because both sides have valid points;
commit to a stance grounded in the debate's strongest arguments.
Additionally, develop a detailed investment plan for the trader. This should include:
Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations.
Use these insights to refine your decision-making and ensure you are learning and improving.
Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes:
\"{past_memory_str}\"
Here is the debate:
Debate History:
{history}"""

    response = ChatOpenAI(model=llm).invoke(prompt)

    new_investment_debate_state = InvestmentDebateState(
        judge_decision=response.content,
        history=investment_debate_state.history,
        bear_history=investment_debate_state.bear_history,
        bull_history=investment_debate_state.bull_history,
        current_response=response.content,
        count=investment_debate_state.count,
    )
    state.investment_debate_state = new_investment_debate_state
    state.investment_plan = response.content
    return state
# {{/docs-fragment research_manager}}

@env.task
async def create_risk_manager(llm: str, state: AgentState) -> AgentState:
    history = state.risk_debate_state.history
    risk_debate_state = state.risk_debate_state
    trader_plan = state.investment_plan
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="risk-manager")
    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)
    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    prompt = f"""As the Risk Management Judge and Debate Facilitator,
your goal is to evaluate the debate between three risk analysts (Risky,
Neutral, and Safe/Conservative) and determine the best course of action for the trader.
Your decision must result in a clear recommendation: Buy, Sell, or Hold.
Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid.
Strive for clarity and decisiveness.
Guidelines for Decision-Making:
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**,
and adjust it based on the analysts' insights.
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments
and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
Deliverables:
- A clear and actionable recommendation: Buy, Sell, or Hold.
- Detailed reasoning anchored in the debate and past reflections.
---
**Analysts Debate History:**
{history}
---
Focus on actionable insights and continuous improvement.
Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""

    response = ChatOpenAI(model=llm).invoke(prompt)

    new_risk_debate_state = RiskDebateState(
        judge_decision=response.content,
        history=risk_debate_state.history,
        risky_history=risk_debate_state.risky_history,
        safe_history=risk_debate_state.safe_history,
        neutral_history=risk_debate_state.neutral_history,
        latest_speaker="Judge",
        current_risky_response=risk_debate_state.current_risky_response,
        current_safe_response=risk_debate_state.current_safe_response,
        current_neutral_response=risk_debate_state.current_neutral_response,
        count=risk_debate_state.count,
    )
    state.risk_debate_state = new_risk_debate_state
    state.final_trade_decision = response.content
    return state
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/managers.py)
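Each agent above builds `past_memory_str` from `memory.get_memories(curr_situation, n_matches=2)`, reading the `"recommendation"` key of each result. The tutorial's memory is backed by a vector database; the sketch below swaps in a hypothetical in-memory store that ranks past situations by bag-of-words cosine similarity, which is enough to show the retrieval step without an embedding model. `ToyMemory`, `embed`, and the `matched_situation` and `similarity_score` keys are illustrative assumptions.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    # Toy "embedding": lowercase word counts. A real setup would call an embedding model.
    return Counter(re.findall(r"[a-z]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[w] * b[w] for w in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


class ToyMemory:
    """Hypothetical in-memory stand-in for the tutorial's vector-backed memory."""

    def __init__(self) -> None:
        self._items = []  # (embedding, situation, recommendation) triples

    def add(self, situation: str, recommendation: str) -> None:
        self._items.append((embed(situation), situation, recommendation))

    def get_memories(self, current_situation: str, n_matches: int = 1) -> list:
        query = embed(current_situation)
        # Rank every stored situation by similarity to the current one, best first.
        scored = sorted(
            ((cosine(query, vec), sit, rec) for vec, sit, rec in self._items),
            key=lambda item: item[0],
            reverse=True,
        )
        return [
            {"matched_situation": sit, "recommendation": rec, "similarity_score": score}
            for score, sit, rec in scored[:n_matches]
        ]


# Usage, following the agents' pattern of concatenating the top recommendations.
memory = ToyMemory()
memory.add("high inflation and rising rates hit tech stocks", "Reduce exposure to long-duration tech.")
memory.add("strong earnings beat and raised guidance", "Trim into strength rather than chasing.")
recs = memory.get_memories("rates rising while inflation stays high", n_matches=2)
past_memory_str = "".join(rec["recommendation"] + "\n\n" for rec in recs)
```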
### Trading agent
The trader agent consolidates the insights from analysts and researchers to generate a final recommendation. It synthesizes competing signals and produces a conclusion such as _Buy for long-term growth despite short-term volatility_.
```python
from agents.utils.utils import AgentState, memory_init
from flyte_env import env
from langchain_core.messages import convert_to_openai_messages
from langchain_openai import ChatOpenAI

# {{docs-fragment trader}}
@env.task
async def create_trader(llm: str, state: AgentState) -> AgentState:
    company_name = state.company_of_interest
    investment_plan = state.investment_plan
    market_research_report = state.market_report
    sentiment_report = state.sentiment_report
    news_report = state.news_report
    fundamentals_report = state.fundamentals_report

    memory = await memory_init(name="trader")
    curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
    past_memories = memory.get_memories(curr_situation, n_matches=2)
    past_memory_str = ""
    for rec in past_memories:
        past_memory_str += rec["recommendation"] + "\n\n"

    context = {
        "role": "user",
        "content": f"Based on a comprehensive analysis by a team of analysts, "
        f"here is an investment plan tailored for {company_name}. "
        "This plan incorporates insights from current technical market trends, "
        "macroeconomic indicators, and social media sentiment. "
        "Use this plan as a foundation for evaluating your next trading decision.\n\n"
        f"Proposed Investment Plan: {investment_plan}\n\n"
        "Leverage these insights to make an informed and strategic decision.",
    }

    messages = [
        {
            "role": "system",
            "content": f"""You are a trading agent analyzing market data to make investment decisions.
Based on your analysis, provide a specific recommendation to buy, sell, or hold.
End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**'
to confirm your recommendation.
Do not forget to utilize lessons from past decisions to learn from your mistakes.
Here are some reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
        },
        context,
    ]

    result = ChatOpenAI(model=llm).invoke(messages)

    state.messages.append(convert_to_openai_messages(result))
    state.trader_investment_plan = result.content
    state.sender = "Trader"
    return state
# {{/docs-fragment trader}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/trader.py)
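Because the trader's system prompt requires every response to end with `FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**`, downstream code can recover the decision with a regular expression. A minimal sketch; `extract_decision` is a hypothetical helper, not part of the tutorial, and it takes the last match in case the phrase also appears earlier in the text.

```python
import re
from typing import Optional

# Matches the suffix the trader prompt asks for, e.g. "FINAL TRANSACTION PROPOSAL: **BUY**".
_PROPOSAL = re.compile(r"FINAL TRANSACTION PROPOSAL:\s*\**\s*(BUY|HOLD|SELL)\b", re.IGNORECASE)


def extract_decision(text: str) -> Optional[str]:
    """Return BUY, HOLD, or SELL from the last proposal in `text`, or None if absent."""
    matches = _PROPOSAL.findall(text)
    return matches[-1].upper() if matches else None
```

Returning `None` rather than a default such as HOLD makes a missing proposal visible to the caller instead of silently turning it into a trade decision.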
### Risk agents
The risk team consists of three agents with different risk tolerances: a risky debater, a neutral one, and a safe (conservative) one. They assess the trader's plan through lenses such as market volatility, liquidity, and systemic risk. As in the bull-bear debate, these agents argue among themselves, after which a risk manager makes the final call.
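The three-way debate can be driven by the same kind of counter used in the bull-bear debate, with `latest_speaker` deciding who goes next. A minimal sketch, assuming a hypothetical fixed rotation (Risky, then Safe, then Neutral), a trimmed-down `RiskDebate` stand-in for `RiskDebateState`, and a placeholder in place of each debater's LLM call; once every debater has spoken `max_risk_rounds` times, control passes to the judge.

```python
from dataclasses import dataclass


@dataclass
class RiskDebate:
    # Hypothetical stand-in using field names from RiskDebateState.
    history: str = ""
    latest_speaker: str = ""
    count: int = 0


# Hypothetical fixed rotation: Risky -> Safe -> Neutral -> Risky ...
NEXT_SPEAKER = {"": "Risky", "Risky": "Safe", "Safe": "Neutral", "Neutral": "Risky"}


def next_step(state: RiskDebate, max_risk_rounds: int) -> str:
    # Each round gives all three debaters one turn; then the judge decides.
    if state.count >= 3 * max_risk_rounds:
        return "Judge"
    return NEXT_SPEAKER[state.latest_speaker]


def run_risk_debate(max_risk_rounds: int) -> RiskDebate:
    state = RiskDebate()
    while (speaker := next_step(state, max_risk_rounds)) != "Judge":
        # Placeholder for the debater's LLM call (create_risky_debator, etc.).
        state.history += f"\n{speaker} Analyst: turn {state.count + 1}"
        state.latest_speaker = speaker
        state.count += 1
    # Mirrors create_risk_manager, which records "Judge" as the latest speaker.
    state.latest_speaker = "Judge"
    return state
```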
```
from agents.utils.utils import AgentState, RiskDebateState
from flyte_env import env
from langchain_openai import ChatOpenAI
# {{docs-fragment risk_debator}}
@env.task
async def create_risky_debator(llm: str, state: AgentState) -> AgentState:
risk_debate_state = state.risk_debate_state
history = risk_debate_state.history
risky_history = risk_debate_state.risky_history
current_safe_response = risk_debate_state.current_safe_response
current_neutral_response = risk_debate_state.current_neutral_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
trader_decision = state.trader_investment_plan
prompt = f"""As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities,
emphasizing bold strategies and competitive advantages.
When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential,
and innovative benefitsβeven when these come with elevated risk.
Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views.
Specifically, respond directly to each point made by the conservative and neutral analysts,
countering with data-driven rebuttals and persuasive reasoning.
Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative.
Here is the trader's decision:
{trader_decision}
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative
and neutral stances to demonstrate why your high-reward perspective offers the best path forward.
Incorporate insights from the following sources into your arguments:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here are the last arguments from the conservative analyst: {current_safe_response}
Here are the last arguments from the neutral analyst: {current_neutral_response}.
If there are no responses from the other viewpoints, do not halluncinate and just present your point.
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic,
and asserting the benefits of risk-taking to outpace market norms.
Maintain a focus on debating and persuading, not just presenting data.
Challenge each counterpoint to underscore why a high-risk approach is optimal.
Output conversationally as if you are speaking without any special formatting."""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Risky Analyst: {response.content}"
new_risk_debate_state = RiskDebateState(
history=history + "\n" + argument,
risky_history=risky_history + "\n" + argument,
safe_history=risk_debate_state.safe_history,
neutral_history=risk_debate_state.neutral_history,
latest_speaker="Risky",
current_risky_response=argument,
current_safe_response=current_safe_response,
current_neutral_response=current_neutral_response,
count=risk_debate_state.count + 1,
)
state.risk_debate_state = new_risk_debate_state
return state
# {{/docs-fragment risk_debator}}
@env.task
async def create_safe_debator(llm: str, state: AgentState) -> AgentState:
risk_debate_state = state.risk_debate_state
history = risk_debate_state.history
safe_history = risk_debate_state.safe_history
current_risky_response = risk_debate_state.current_risky_response
current_neutral_response = risk_debate_state.current_neutral_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
trader_decision = state.trader_investment_plan
prompt = f"""As the Safe/Conservative Risk Analyst, your primary objective is to protect assets,
minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation,
carefully assessing potential losses, economic downturns, and market volatility.
When evaluating the trader's decision or plan, critically examine high-risk elements,
pointing out where the decision may expose the firm to undue risk and where more cautious
alternatives could secure long-term gains.
Here is the trader's decision:
{trader_decision}
Your task is to actively counter the arguments of the Risky and Neutral Analysts,
highlighting where their views may overlook potential threats or fail to prioritize sustainability.
Respond directly to their points, drawing from the following data sources
to build a convincing case for a low-risk approach adjustment to the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here is the last response from the risky analyst: {current_risky_response}
Here is the last response from the neutral analyst: {current_neutral_response}.
If there are no responses from the other viewpoints, do not halluncinate and just present your point.
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked.
Address each of their counterpoints to showcase why a conservative stance is ultimately the
safest path for the firm's assets.
Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy
over their approaches.
Output conversationally as if you are speaking without any special formatting."""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Safe Analyst: {response.content}"
new_risk_debate_state = RiskDebateState(
history=history + "\n" + argument,
risky_history=risk_debate_state.risky_history,
safe_history=safe_history + "\n" + argument,
neutral_history=risk_debate_state.neutral_history,
latest_speaker="Safe",
current_risky_response=current_risky_response,
current_safe_response=argument,
current_neutral_response=current_neutral_response,
count=risk_debate_state.count + 1,
)
state.risk_debate_state = new_risk_debate_state
return state
@env.task
async def create_neutral_debator(llm: str, state: AgentState) -> AgentState:
risk_debate_state = state.risk_debate_state
history = risk_debate_state.history
neutral_history = risk_debate_state.neutral_history
current_risky_response = risk_debate_state.current_risky_response
current_safe_response = risk_debate_state.current_safe_response
market_research_report = state.market_report
sentiment_report = state.sentiment_report
news_report = state.news_report
fundamentals_report = state.fundamentals_report
trader_decision = state.trader_investment_plan
prompt = f"""As the Neutral Risk Analyst, your role is to provide a balanced perspective,
weighing both the potential benefits and risks of the trader's decision or plan.
You prioritize a well-rounded approach, evaluating the upsides
and downsides while factoring in broader market trends,
potential economic shifts, and diversification strategies. Here is the trader's decision:
{trader_decision}
Your task is to challenge both the Risky and Safe Analysts,
pointing out where each perspective may be overly optimistic or overly cautious.
Use insights from the following data sources to support a moderate, sustainable strategy
to adjust the trader's decision:
Market Research Report: {market_research_report}
Social Media Sentiment Report: {sentiment_report}
Latest World Affairs Report: {news_report}
Company Fundamentals Report: {fundamentals_report}
Here is the current conversation history: {history}
Here is the last response from the risky analyst: {current_risky_response}
Here is the last response from the safe analyst: {current_safe_response}.
If there are no responses from the other viewpoints, do not hallucinate and just present your point.
Engage actively by analyzing both sides critically, addressing weaknesses in the risky
and conservative arguments to advocate for a more balanced approach.
Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds,
providing growth potential while safeguarding against extreme volatility.
Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to
the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
response = ChatOpenAI(model=llm).invoke(prompt)
argument = f"Neutral Analyst: {response.content}"
new_risk_debate_state = RiskDebateState(
history=history + "\n" + argument,
risky_history=risk_debate_state.risky_history,
safe_history=risk_debate_state.safe_history,
neutral_history=neutral_history + "\n" + argument,
latest_speaker="Neutral",
current_risky_response=current_risky_response,
current_safe_response=current_safe_response,
current_neutral_response=argument,
count=risk_debate_state.count + 1,
)
state.risk_debate_state = new_risk_debate_state
return state
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/agents/risk_debators.py)
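Every debator shown above ends the same way. It builds a new `RiskDebateState` that appends its argument to the shared history and to its own per-speaker history, records itself as `latest_speaker`, and increments `count`. The following is a minimal sketch of that bookkeeping factored into one helper. It assumes a plain frozen dataclass in place of the tutorial's `RiskDebateState`, and `DebateState` and `record_turn` are hypothetical names that are not part of the example repo:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class DebateState:
    # Hypothetical stand-in for the tutorial's RiskDebateState
    history: str = ""
    risky_history: str = ""
    safe_history: str = ""
    neutral_history: str = ""
    latest_speaker: str = ""
    current_risky_response: str = ""
    current_safe_response: str = ""
    current_neutral_response: str = ""
    count: int = 0


def record_turn(state: DebateState, speaker: str, content: str) -> DebateState:
    """Return a new state with one speaker's turn appended; the input is not mutated."""
    argument = f"{speaker} Analyst: {content}"
    key = speaker.lower()  # "Risky" -> "risky", and so on
    return replace(
        state,
        history=state.history + "\n" + argument,
        # Update only this speaker's history and latest response
        **{
            f"{key}_history": getattr(state, f"{key}_history") + "\n" + argument,
            f"current_{key}_response": argument,
        },
        latest_speaker=speaker,
        count=state.count + 1,
    )
```

Because each turn produces a fresh object, earlier states stay intact. This makes individual turns easy to compare or replay when you debug a debate.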
The risk manager's verdict (whether to proceed with the trade) is treated as the final decision of the trading simulation.
You can visualize the full pipeline in the Flyte/Union UI, where every step is logged.
You'll see input and output metadata for each tool and agent task.
Thanks to Flyte's caching, repeated steps are skipped unless inputs change, saving time and compute resources.
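The `process_signal` task in `main.py` asks an LLM to reduce the risk manager's final text to `BUY`, `SELL`, or `HOLD`. However, the model's reply is not guaranteed to be exactly one of those words. The following sketch shows a small guard you could run on that reply. The `normalize_decision` helper and the fallback to `HOLD` for ambiguous replies are assumptions, not part of the tutorial:

```python
import re

VALID_DECISIONS = ("BUY", "SELL", "HOLD")


def normalize_decision(raw: str, default: str = "HOLD") -> str:
    """Map an LLM reply to BUY/SELL/HOLD, falling back to `default` if ambiguous."""
    # Collect every standalone occurrence of a valid label, case-insensitively
    found = {
        match.upper()
        for match in re.findall(r"\b(buy|sell|hold)\b", raw, flags=re.IGNORECASE)
    }
    # Accept the reply only if it names exactly one distinct decision
    if len(found) == 1:
        return found.pop()
    return default
```

Falling back to a neutral label keeps a malformed reply from being silently treated as a trade signal. Whether `HOLD` is the right default depends on your use case.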
### Retaining agent memory with S3 vectors
To help agents learn from past decisions, we persist their memory in a vector store. In this example, we use an [S3 vector](https://aws.amazon.com/s3/features/vectors/) bucket for its simplicity and tight integration with Flyte and Union, but any vector database can be used.
Note: To use the S3 vector store, make sure your IAM role has the following permissions configured:
```
s3vectors:CreateVectorBucket
s3vectors:CreateIndex
s3vectors:PutVectors
s3vectors:GetIndex
s3vectors:GetVectors
s3vectors:QueryVectors
s3vectors:GetVectorBucket
```
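If you manage IAM as code, you can grant these permissions with a standard identity-based policy document. The following sketch renders one. It assumes you attach the policy to the role your tasks run under. `"Resource": "*"` is a placeholder that you would likely narrow to your vector bucket's ARN, and the `Sid` value is an arbitrary label:

```python
import json

# The S3 Vectors actions listed above
S3_VECTORS_ACTIONS = [
    "s3vectors:CreateVectorBucket",
    "s3vectors:CreateIndex",
    "s3vectors:PutVectors",
    "s3vectors:GetIndex",
    "s3vectors:GetVectors",
    "s3vectors:QueryVectors",
    "s3vectors:GetVectorBucket",
]


def build_policy(resource: str = "*") -> dict:
    """Build an IAM policy document that allows the S3 Vectors actions on `resource`."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "AllowS3VectorsForAgentMemory",  # arbitrary label
                "Effect": "Allow",
                "Action": S3_VECTORS_ACTIONS,
                "Resource": resource,  # placeholder: scope this down in practice
            }
        ],
    }


if __name__ == "__main__":
    print(json.dumps(build_policy(), indent=2))
```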
After each trade decision, you can run a `reflect_on_decisions` task. This evaluates whether the final outcome aligned with the agent's recommendation and stores that reflection in the vector store. These stored insights can later be retrieved to provide historical context and improve future decision-making.
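Conceptually, the reflection memory works in three steps:

1. Embed a description of the situation.
2. Store the reflection next to that vector.
3. Later, retrieve the reflections whose situations are most similar to the current one.

The following self-contained sketch illustrates that idea. It assumes an in-memory list in place of an S3 vector index, and a toy hashed bag-of-words embedding in place of a real embedding model. `ReflectionMemory` and `embed` are hypothetical names and are not the tutorial's `reflection` module:

```python
import math
import re
import zlib
from collections import Counter
from dataclasses import dataclass, field

DIM = 256  # toy embedding size; a real embedding model defines its own dimension


def embed(text: str) -> list[float]:
    """Toy embedding: hashed bag of words, L2-normalized (stand-in for a real model)."""
    vec = [0.0] * DIM
    for word, n in Counter(re.findall(r"[a-z0-9]+", text.lower())).items():
        vec[zlib.crc32(word.encode()) % DIM] += n
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


def cosine(a: list[float], b: list[float]) -> float:
    # Both vectors are unit length, so the dot product is the cosine similarity
    return sum(x * y for x, y in zip(a, b))


@dataclass
class ReflectionMemory:
    """In-memory stand-in for a vector index of (situation, reflection) pairs."""

    items: list[tuple[list[float], str]] = field(default_factory=list)

    def add(self, situation: str, reflection: str) -> None:
        # Roughly what storing a vector plus the reflection as metadata would do
        self.items.append((embed(situation), reflection))

    def query(self, situation: str, k: int = 1) -> list[str]:
        """Return up to k reflections whose stored situations are most similar."""
        q = embed(situation)
        ranked = sorted(self.items, key=lambda item: cosine(q, item[0]), reverse=True)
        return [reflection for _, reflection in ranked[:k]]
```

Swapping in a real embedding model and a persistent vector index changes where the vectors live, but not this add-then-query flow.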
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "akshare==1.16.98",
# "backtrader==1.9.78.123",
# "boto3==1.39.9",
# "chainlit==2.5.5",
# "eodhd==1.0.32",
# "feedparser==6.0.11",
# "finnhub-python==2.4.23",
# "langchain-experimental==0.3.4",
# "langchain-openai==0.3.23",
# "pandas==2.3.0",
# "parsel==1.10.0",
# "praw==7.8.1",
# "pytz==2025.2",
# "questionary==2.1.0",
# "redis==6.2.0",
# "requests==2.32.4",
# "stockstats==0.6.5",
# "tqdm==4.67.1",
# "tushare==1.4.21",
# "typing-extensions==4.14.0",
# "yfinance==0.2.63",
# ]
# main = "main"
# params = ""
# ///
import asyncio
from copy import deepcopy
import agents
import agents.analysts
from agents.managers import create_research_manager, create_risk_manager
from agents.researchers import create_bear_researcher, create_bull_researcher
from agents.risk_debators import (
create_neutral_debator,
create_risky_debator,
create_safe_debator,
)
from agents.trader import create_trader
from agents.utils.utils import AgentState
from flyte_env import DEEP_THINKING_LLM, QUICK_THINKING_LLM, env, flyte
from langchain_openai import ChatOpenAI
from reflection import (
reflect_bear_researcher,
reflect_bull_researcher,
reflect_research_manager,
reflect_risk_manager,
reflect_trader,
)
@env.task
async def process_signal(full_signal: str, QUICK_THINKING_LLM: str) -> str:
"""Process a full trading signal to extract the core decision."""
messages = [
{
"role": "system",
"content": """You are an efficient assistant designed to analyze paragraphs or
financial reports provided by a group of analysts.
Your task is to extract the investment decision: SELL, BUY, or HOLD.
Provide only the extracted decision (SELL, BUY, or HOLD) as your output,
without adding any additional text or information.""",
},
{"role": "human", "content": full_signal},
]
return ChatOpenAI(model=QUICK_THINKING_LLM).invoke(messages).content
async def run_analyst(analyst_name, state, online_tools):
# Look up the analyst factory by name (the caller already passes in an isolated copy of the state)
run_fn = getattr(agents.analysts, f"create_{analyst_name}_analyst")
# Run the analyst's chain
result_state = await run_fn(QUICK_THINKING_LLM, state, online_tools)
# Determine the report key
report_key = (
"sentiment_report"
if analyst_name == "social_media"
else f"{analyst_name}_report"
)
report_value = getattr(result_state, report_key)
return result_state.messages[1:], report_key, report_value
# {{docs-fragment main}}
@env.task
async def main(
selected_analysts: list[str] = [
"market",
"fundamentals",
"news",
"social_media",
],
max_debate_rounds: int = 1,
max_risk_discuss_rounds: int = 1,
online_tools: bool = True,
company_name: str = "NVDA",
trade_date: str = "2024-05-12",
) -> tuple[str, AgentState]:
if not selected_analysts:
raise ValueError(
"No analysts selected. Please select at least one analyst from market, fundamentals, news, or social_media."
)
state = AgentState(
messages=[{"role": "human", "content": company_name}],
company_of_interest=company_name,
trade_date=str(trade_date),
)
# Run all analysts concurrently
results = await asyncio.gather(
*[
run_analyst(analyst, deepcopy(state), online_tools)
for analyst in selected_analysts
]
)
# Flatten and append all resulting messages into the shared state
for messages, report_attr, report in results:
state.messages.extend(messages)
setattr(state, report_attr, report)
# Bull/Bear debate loop
state = await create_bull_researcher(QUICK_THINKING_LLM, state) # Start with bull
while state.investment_debate_state.count < 2 * max_debate_rounds:
current = state.investment_debate_state.current_response
if current.startswith("Bull"):
state = await create_bear_researcher(QUICK_THINKING_LLM, state)
else:
state = await create_bull_researcher(QUICK_THINKING_LLM, state)
state = await create_research_manager(DEEP_THINKING_LLM, state)
state = await create_trader(QUICK_THINKING_LLM, state)
# Risk debate loop
state = await create_risky_debator(QUICK_THINKING_LLM, state) # Start with risky
while state.risk_debate_state.count < 3 * max_risk_discuss_rounds:
speaker = state.risk_debate_state.latest_speaker
if speaker == "Risky":
state = await create_safe_debator(QUICK_THINKING_LLM, state)
elif speaker == "Safe":
state = await create_neutral_debator(QUICK_THINKING_LLM, state)
else:
state = await create_risky_debator(QUICK_THINKING_LLM, state)
state = await create_risk_manager(DEEP_THINKING_LLM, state)
decision = await process_signal(state.final_trade_decision, QUICK_THINKING_LLM)
return decision, state
# {{/docs-fragment main}}
# {{docs-fragment reflect_on_decisions}}
@env.task
async def reflect_and_store(state: AgentState, returns: str) -> str:
await asyncio.gather(
reflect_bear_researcher(state, returns),
reflect_bull_researcher(state, returns),
reflect_trader(state, returns),
reflect_risk_manager(state, returns),
reflect_research_manager(state, returns),
)
return "Reflection completed."
# Run the reflection task after the main function
@env.task(cache="disable")
async def reflect_on_decisions(
returns: str,
selected_analysts: list[str] = [
"market",
"fundamentals",
"news",
"social_media",
],
max_debate_rounds: int = 1,
max_risk_discuss_rounds: int = 1,
online_tools: bool = True,
company_name: str = "NVDA",
trade_date: str = "2024-05-12",
) -> str:
_, state = await main(
selected_analysts,
max_debate_rounds,
max_risk_discuss_rounds,
online_tools,
company_name,
trade_date,
)
return await reflect_and_store(state, returns)
# {{/docs-fragment reflect_on_decisions}}
# {{docs-fragment execute_main}}
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
# print(run.url)
# {{/docs-fragment execute_main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py)
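A turn counter on the debate state drives the two debate loops in `main`. Each loop opens with a fixed speaker (Bull, then Risky). It then picks the next speaker based on who spoke last, until `count` reaches `2 * max_debate_rounds` (bull/bear) or `3 * max_risk_discuss_rounds` (risky/safe/neutral). The following pure-Python sketch reproduces only that turn order, with no LLM calls, so you can see how many debate tasks a given setting launches. It assumes each participant increments `count` by one per turn, as the debators above do. The function names are hypothetical:

```python
def investment_debate_order(max_debate_rounds: int) -> list[str]:
    """Bull/bear loop: starts with Bull, then alternates; two turns per round."""
    order = ["Bull"]  # main always calls the bull researcher first
    while len(order) < 2 * max_debate_rounds:
        order.append("Bear" if order[-1] == "Bull" else "Bull")
    return order


def risk_debate_order(max_risk_discuss_rounds: int) -> list[str]:
    """Risk loop: Risky -> Safe -> Neutral -> Risky ...; three turns per round."""
    next_speaker = {"Risky": "Safe", "Safe": "Neutral", "Neutral": "Risky"}
    order = ["Risky"]  # main always calls the risky debator first
    while len(order) < 3 * max_risk_discuss_rounds:
        order.append(next_speaker[order[-1]])
    return order
```

For example, raising `max_risk_discuss_rounds` from 1 to 2 doubles the risk debate from three LLM-backed turns to six. The research manager, trader, and risk manager each still run once.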
### Running the simulation
First, store your OpenAI API key (from [openai.com](https://platform.openai.com/api-keys)) and your Finnhub API key (from [finnhub.io](https://finnhub.io/)) as Flyte secrets:
```
flyte create secret openai_api_key
flyte create secret finnhub_api_key
```
Then [clone the repo](https://github.com/unionai/unionai-examples), navigate to the `v2/tutorials/trading_agents` directory, and run the following commands:
```
flyte create config --endpoint <your-endpoint> --project <your-project> --domain <your-domain> --builder remote
uv run main.py
```
If you'd like to run the `reflect_on_decisions` task instead, comment out the `flyte.run(main)` call and uncomment the `flyte.run(reflect_on_decisions, ...)` lines in the `__main__` block:
```
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
    # run = flyte.run(reflect_on_decisions, "+3.2% gain over 5 days")
    # print(run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/trading_agents/main.py)
Then run:
```
uv run --prerelease=allow main.py
```
## Why Flyte? _(A quick note before you go)_
You might now be wondering: can't I just build all this with Python and LangChain?
Absolutely. But as your project grows, you'll likely run into these challenges:
1. **Observability**: Agent workflows can feel opaque. You send a prompt and get a response, but what happened in between?
   - Were the right tools used?
   - Were the correct arguments passed?
   - How did the LLM reason through intermediate steps?
   - Why did it fail?

   Flyte gives you a window into each of these stages.
2. **Multi-agent coordination**: Real-world applications often require multiple agents with distinct roles and responsibilities. In such cases, you'll need:
   - Isolated state per agent
   - Shared context where needed
   - Coordination, whether sequential or parallel

   Managing this manually gets fragile fast. Flyte handles it for you.
3. **Scalability**: Agents and tools might need to run in isolated or containerized environments. Whether you're scaling out to more agents or to more powerful hardware, Flyte lets you scale without taxing your local machine or racking up unnecessary cloud bills.
4. **Durability & recovery**: LLM-based workflows are often long-running and expensive. If something fails halfway:
   - Do you lose all progress?
   - Do you replay everything from scratch?

   With Flyte, you get built-in caching, checkpointing, and recovery, so you can resume where you left off.
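The recovery benefit comes from keying each step's output on its inputs. When a run is retried, steps whose inputs are unchanged can reuse their stored result instead of calling the LLM again. The following is a toy, in-process sketch of that idea only. It is not how Flyte implements caching, and `StepCache` and `cache_key` are hypothetical names that keep results in memory:

```python
import hashlib
import json
from typing import Any, Callable


def cache_key(task_name: str, inputs: dict[str, Any]) -> str:
    """Deterministic key from a task name and its JSON-serializable inputs."""
    payload = json.dumps({"task": task_name, "inputs": inputs}, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


class StepCache:
    """Toy memoizer: identical (task, inputs) pairs reuse the stored output."""

    def __init__(self) -> None:
        self._store: dict[str, Any] = {}
        self.executions = 0  # how many times a step actually ran

    def run(self, task_name: str, fn: Callable[..., Any], **inputs: Any) -> Any:
        key = cache_key(task_name, inputs)
        if key not in self._store:
            self.executions += 1
            self._store[key] = fn(**inputs)
        return self._store[key]
```

A durable system persists these results outside the process. A crash halfway through a run then loses only the step that was in flight.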
=== PAGE: https://www.union.ai/docs/v2/byoc/tutorials/code-agent ===
# Run LLM-generated code
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/code_runner).
This example demonstrates how to run code generated by a large language model (LLM) using a `ContainerTask`.
The agent takes a user's question, generates Flyte 2 code using the Flyte 2 documentation as context, and runs that code in an isolated container.
If execution fails, the agent reflects on the error and retries until the code succeeds or a configurable iteration limit is reached.
Running the generated code in a `ContainerTask` keeps it separate from the agent's own environment, with its own image and resource limits.
This lets you execute arbitrary generated logic without running it inside the agent task itself.
## What this example demonstrates
- How to combine LLM generation with programmatic execution.
- How to run untrusted or dynamically generated code securely.
- How to iteratively improve code using agent-like behavior.
## Setting up the agent environment
Let's start by importing the necessary libraries and setting up two environments: one for the container task and another for the agent task.
This example follows the `uv` script format to declare dependencies.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b23",
# "langchain-core==0.3.66",
# "langchain-openai==0.3.24",
# "langchain-community==0.3.26",
# "beautifulsoup4==4.13.4",
# "docker==7.1.0",
# ]
# ///
```
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "langchain-core==0.3.66",
# "langchain-openai==0.3.24",
# "langchain-community==0.3.26",
# "beautifulsoup4==4.13.4",
# "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File
code_runner_task = ContainerTask(
name="run_flyte_v2",
image=flyte.Image.from_debian_base(),
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs={"script": File},
outputs={"result": str, "exit_code": str},
command=[
"/bin/bash",
"-c",
(
"set -o pipefail && "
"uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
"echo $? > /var/outputs/exit_code"
),
],
resources=flyte.Resources(cpu=1, memory="1Gi"),
)
# {{/docs-fragment code_runner_task}}
# {{docs-fragment env}}
import tempfile
from typing import Optional
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field
container_env = flyte.TaskEnvironment.from_task(
"code-runner-container", code_runner_task
)
env = flyte.TaskEnvironment(
name="code_runner",
secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
resources=flyte.Resources(cpu=1),
depends_on=[container_env],
)
# {{/docs-fragment env}}
# {{docs-fragment code_base_model}}
class Code(BaseModel):
"""Schema for code solutions to questions about Flyte v2."""
prefix: str = Field(
default="", description="Description of the problem and approach"
)
imports: str = Field(
default="", description="Code block with just import statements"
)
code: str = Field(
default="", description="Code block not including import statements"
)
# {{/docs-fragment code_base_model}}
# {{docs-fragment agent_state}}
class AgentState(BaseModel):
messages: list[dict[str, str]] = Field(default_factory=list)
generation: Code = Field(default_factory=Code)
iterations: int = 0
error: str = "no"
output: Optional[str] = None
# {{/docs-fragment agent_state}}
# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
# Code-generation prompt
code_gen_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.
Use the following pattern to execute the code:
if __name__ == "__main__":
flyte.init_from_config()
print(flyte.run(...))
Your response will be shown to the user.
Here is a full set of documentation:
-------
{context}
-------
Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
),
("placeholder", "{messages}"),
]
)
expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
llm = ChatOpenAI(temperature=0, model=expt_llm)
code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
return code_gen_chain
# {{/docs-fragment generate_code_gen_chain}}
# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
from bs4 import BeautifulSoup
from langchain_community.document_loaders.recursive_url_loader import (
RecursiveUrlLoader,
)
loader = RecursiveUrlLoader(
url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
)
docs = loader.load()
# Sort the list based on the URLs and get the text
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
d_reversed = list(reversed(d_sorted))
concatenated_content = "\n\n\n --- \n\n\n".join(
[doc.page_content for doc in d_reversed]
)
return concatenated_content
# {{/docs-fragment docs_retriever}}
# {{docs-fragment generate}}
@env.task
async def generate(
question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
"""
Generate a code solution
Args:
question (str): The user question
state (AgentState): The current agent state
concatenated_content (str): The concatenated docs content
debug (bool): Debug mode (uses the smaller gpt-4o-mini model)
Returns:
AgentState: Updated state with the new generation and an incremented iteration count
"""
print("---GENERATING CODE SOLUTION---")
messages = state.messages
iterations = state.iterations
error = state.error
# We have been routed back to generation with an error
if error == "yes":
messages += [
{
"role": "user",
"content": (
"Now, try again. Invoke the code tool to structure the output "
"with a prefix, imports, and code block:"
),
}
]
code_gen_chain = await generate_code_gen_chain(debug)
# Solution
code_solution = code_gen_chain.invoke(
{
"context": concatenated_content,
"messages": (
messages if messages else [{"role": "user", "content": question}]
),
}
)
messages += [
{
"role": "assistant",
"content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
}
]
return AgentState(
messages=messages,
generation=code_solution,
iterations=iterations + 1,
error=error,
output=state.output,
)
# {{/docs-fragment generate}}
# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
"""
Check code
Args:
state (AgentState): The current agent state
Returns:
AgentState: Updated state with `error` set to "yes" or "no" and the execution output
"""
print("---CHECKING CODE---")
# State
messages = state.messages
code_solution = state.generation
iterations = state.iterations
# Get solution components
imports = code_solution.imports.strip()
code = code_solution.code.strip()
# Create temp file for imports
with tempfile.NamedTemporaryFile(
mode="w", suffix=".py", delete=False
) as imports_file:
imports_file.write(imports + "\n")
imports_path = imports_file.name
# Create temp file for the full script (imports followed by the code body)
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
code_file.write(imports + "\n" + code + "\n")
code_path = code_file.name
# Check imports
import_output, import_exit_code = await code_runner_task(
script=await File.from_local(imports_path)
)
if import_exit_code.strip() != "0":
print("---CODE IMPORT CHECK: FAILED---")
error_message = [
{
"role": "user",
"content": f"Your solution failed the import test: {import_output}",
}
]
messages += error_message
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="yes",
output=import_output,
)
else:
print("---CODE IMPORT CHECK: PASSED---")
# Check execution
code_output, code_exit_code = await code_runner_task(
script=await File.from_local(code_path)
)
if code_exit_code.strip() != "0":
print("---CODE BLOCK CHECK: FAILED---")
error_message = [
{
"role": "user",
"content": f"Your solution failed the code execution test: {code_output}",
}
]
messages += error_message
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="yes",
output=code_output,
)
else:
print("---CODE BLOCK CHECK: PASSED---")
# No errors
print("---NO CODE TEST FAILURES---")
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error="no",
output=code_output,
)
# {{/docs-fragment code_check}}
# {{docs-fragment reflect}}
@env.task
async def reflect(
state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
"""
Reflect on errors
Args:
state (AgentState): The current agent state
concatenated_content (str): Concatenated docs content
debug (bool): Debug mode (uses the smaller gpt-4o-mini model)
Returns:
AgentState: Updated state with the reflection appended to `messages`
"""
print("---REFLECTING---")
# State
messages = state.messages
iterations = state.iterations
code_solution = state.generation
# Prompt reflection
code_gen_chain = await generate_code_gen_chain(debug)
# Add reflection
reflections = code_gen_chain.invoke(
{"context": concatenated_content, "messages": messages}
)
messages += [
{
"role": "assistant",
"content": f"Here are reflections on the error: {reflections}",
}
]
return AgentState(
generation=code_solution,
messages=messages,
iterations=iterations,
error=state.error,
output=state.output,
)
# {{/docs-fragment reflect}}
# {{docs-fragment main}}
@env.task
async def main(
question: str = (
"Define a two-task pattern where the second catches OOM from the first and retries with more memory."
),
url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
max_iterations: int = 3,
debug: bool = False,
) -> str:
concatenated_content = await docs_retriever(url=url)
state: AgentState = AgentState()
iterations = 0
while True:
with flyte.group(f"code-generation-pass-{iterations + 1}"):
state = await generate(question, state, concatenated_content, debug)
state = await code_check(state)
error = state.error
iterations = state.iterations
if error == "no" or iterations >= max_iterations:
print("---DECISION: FINISH---")
code_solution = state.generation
prefix = code_solution.prefix
imports = code_solution.imports
code = code_solution.code
code_output = state.output
return f"""{prefix}
{imports}
{code}
Result of code execution:
{code_output}
"""
else:
print("---DECISION: RE-TRY SOLUTION---")
state = await reflect(state, concatenated_content, debug)
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
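The `ContainerTask` contract works as follows. The agent uploads a script file, and the container command writes combined stdout/stderr to `/var/outputs/result` and the exit status to `/var/outputs/exit_code`. The task then returns these as its `result` and `exit_code` outputs. `code_check` runs this twice: once with only the imports, and once with the full script.

The following sketch emulates that flow locally with temporary directories and `subprocess`. It is meant for understanding and local debugging only. It provides no isolation and substitutes `sys.executable` for `uv run --script`. `run_script_in_dirs` and `check` are hypothetical names:

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def run_script_in_dirs(script_text: str, timeout: float = 60.0) -> tuple[str, str]:
    """Mimic the container contract: script in inputs/, result and exit_code in outputs/."""
    with tempfile.TemporaryDirectory() as root:
        inputs, outputs = Path(root, "inputs"), Path(root, "outputs")
        inputs.mkdir()
        outputs.mkdir()
        (inputs / "script").write_text(script_text)
        # Assumption: run with the current interpreter instead of `uv run --script`
        proc = subprocess.run(
            [sys.executable, str(inputs / "script")],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into stdout, like `2>&1`
            text=True,
            timeout=timeout,
        )
        (outputs / "result").write_text(proc.stdout)
        (outputs / "exit_code").write_text(f"{proc.returncode}\n")
        return (outputs / "result").read_text(), (outputs / "exit_code").read_text()


def check(imports: str, code: str) -> tuple[bool, str]:
    """Two-phase check as in code_check: imports alone first, then imports plus code."""
    output = ""
    for script in (imports + "\n", imports + "\n" + code + "\n"):
        output, exit_code = run_script_in_dirs(script)
        if exit_code.strip() != "0":
            return False, output
    return True, output
```

Checking the imports on their own first lets the agent distinguish "a package is missing" from "the logic is wrong". Each case can then be reported back to the LLM with a more specific message.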
> [!NOTE]
> You can set up access to the OpenAI API using a Flyte secret.
>
> ```
> flyte create secret openai_api_key
> ```
We store the LLM-generated code in a structured format. This allows us to:
- Enforce consistent formatting
- Make debugging easier
- Log and analyze generations systematically
By capturing metadata alongside the raw code, we maintain transparency and make it easier to iterate or trace issues over time.
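Splitting a generation into `prefix`, `imports`, and `code` also makes cheap static checks possible before anything is sent to a container. The following sketch catches malformed generations early. It uses Python's `ast` module on a plain dataclass that mirrors the example's `Code` schema. `CodeSolution` and `validate_solution` are hypothetical names and are not part of the tutorial:

```python
import ast
from dataclasses import dataclass


@dataclass
class CodeSolution:
    """Plain-dataclass mirror of the example's `Code` schema."""

    prefix: str = ""  # description of the problem and approach
    imports: str = ""  # import statements only
    code: str = ""  # code body without imports


def validate_solution(sol: CodeSolution) -> list[str]:
    """Return formatting problems; an empty list means the fields look well-formed."""
    problems: list[str] = []
    try:
        import_nodes = ast.parse(sol.imports).body
    except SyntaxError as exc:
        return [f"imports do not parse: {exc.msg}"]
    # The imports field should contain nothing but import statements
    for node in import_nodes:
        if not isinstance(node, (ast.Import, ast.ImportFrom)):
            problems.append(f"non-import statement in imports (line {node.lineno})")
    try:
        ast.parse(sol.code)
    except SyntaxError as exc:
        problems.append(f"code does not parse: {exc.msg} (line {exc.lineno})")
    if not sol.prefix.strip():
        problems.append("missing description (prefix)")
    return problems
```

Problems found this way could be fed back to the LLM the same way execution errors are, without spending a container run.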
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "langchain-core==0.3.66",
# "langchain-openai==0.3.24",
# "langchain-community==0.3.26",
# "beautifulsoup4==4.13.4",
# "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File
code_runner_task = ContainerTask(
name="run_flyte_v2",
image=flyte.Image.from_debian_base(),
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs={"script": File},
outputs={"result": str, "exit_code": str},
command=[
"/bin/bash",
"-c",
(
"set -o pipefail && "
"uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
"echo $? > /var/outputs/exit_code"
),
],
resources=flyte.Resources(cpu=1, memory="1Gi"),
)
# {{/docs-fragment code_runner_task}}
# {{docs-fragment env}}
import tempfile
from typing import Optional
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field
container_env = flyte.TaskEnvironment.from_task(
"code-runner-container", code_runner_task
)
env = flyte.TaskEnvironment(
name="code_runner",
secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
resources=flyte.Resources(cpu=1),
depends_on=[container_env],
)
# {{/docs-fragment env}}
# {{docs-fragment code_base_model}}
class Code(BaseModel):
"""Schema for code solutions to questions about Flyte v2."""
prefix: str = Field(
default="", description="Description of the problem and approach"
)
imports: str = Field(
default="", description="Code block with just import statements"
)
code: str = Field(
default="", description="Code block not including import statements"
)
# {{/docs-fragment code_base_model}}
# {{docs-fragment agent_state}}
class AgentState(BaseModel):
messages: list[dict[str, str]] = Field(default_factory=list)
generation: Code = Field(default_factory=Code)
iterations: int = 0
error: str = "no"
output: Optional[str] = None
# {{/docs-fragment agent_state}}
# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Code-generation prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.
Use the following pattern to execute the code:
if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))
Your response will be shown to the user.
Here is a full set of documentation:
-------
{context}
-------
Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)
    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain
# {{/docs-fragment generate_code_gen_chain}}
# {{docs-fragment docs_retriever}}
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort the list based on the URLs and get the text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))
    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content
# {{/docs-fragment docs_retriever}}
# {{docs-fragment generate}}
@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution

    Args:
        question (str): The user question
        state (dict): The current graph state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, generation
    """
    print("---GENERATING CODE SOLUTION---")

    messages = state.messages
    iterations = state.iterations
    error = state.error

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]

    code_gen_chain = await generate_code_gen_chain(debug)

    # Solution
    code_solution = code_gen_chain.invoke(
        {
            "context": concatenated_content,
            "messages": (
                messages if messages else [{"role": "user", "content": question}]
            ),
        }
    )

    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]

    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )
# {{/docs-fragment generate}}
# {{docs-fragment code_check}}
@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """
    print("---CHECKING CODE---")

    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations

    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()

    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name

    # Create temp file for code body
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name

    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )

    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")

    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )

    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")

    # No errors
    print("---NO CODE TEST FAILURES---")

    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )
# {{/docs-fragment code_check}}
# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors

    Args:
        state (dict): The current graph state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode

    Returns:
        state (dict): New key added to state, reflection
    """
    print("---REFLECTING---")

    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation

    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)

    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )
# {{/docs-fragment reflect}}
# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)
    state: AgentState = AgentState()
    iterations = 0

    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)

            error = state.error
            iterations = state.iterations

            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")

                code_solution = state.generation
                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code
                code_output = state.output

                return f"""{prefix}
{imports}
{code}
Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
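`code_runner_task` and `code_check` rely on a simple contract. The runner executes a script, captures stdout and stderr together, and reports the exit code as a string. The check first runs the imports alone, then runs the imports together with the code body. The sketch below reproduces that two-stage check locally with `subprocess` so you can see how it behaves without a cluster. It is an illustration only and is not part of the tutorial. It runs scripts with the current Python interpreter instead of `uv run --script`, so it does not install PEP 723 dependencies. The helper names `run_script` and `check_solution` are hypothetical.

```python
import os
import subprocess
import sys
import tempfile


def run_script(source: str, timeout: float = 60.0) -> tuple[str, str]:
    """Run Python source in a fresh interpreter.

    Returns (combined stdout/stderr, exit code as a string), mirroring the
    `result` and `exit_code` outputs of `code_runner_task`.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(source)
        path = f.name
    try:
        proc = subprocess.run(
            [sys.executable, path],  # assumption: plain interpreter instead of `uv run --script`
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into stdout, like `2>&1`
            text=True,
            timeout=timeout,
        )
    finally:
        os.unlink(path)  # clean up the temporary script
    return proc.stdout, str(proc.returncode)


def check_solution(imports: str, code: str) -> tuple[str, str]:
    """Hypothetical local analogue of `code_check`: imports first, then the full script."""
    output, exit_code = run_script(imports + "\n")
    if exit_code != "0":
        return "import", output
    output, exit_code = run_script(imports + "\n" + code + "\n")
    if exit_code != "0":
        return "execution", output
    return "ok", output


if __name__ == "__main__":
    print(check_solution("import math", "print(math.sqrt(16))"))
```

Checking the imports on their own gives the LLM more targeted feedback. A missing package is reported as an import failure rather than buried in a traceback from the full script.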
We then define a state model that persists the agent's history across iterations. `AgentState` holds the message history, the most recently generated `Code` solution, the iteration count, an error flag, and the output of the last execution.
Keeping this history lets the agent reflect on earlier attempts, avoid repeating mistakes,
and improve the generated code iteratively.
```
class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
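Each step of the loop runs as a separate task, so the state must survive being handed from one task to the next. The sketch below takes a standalone copy of the two models through one failed pass and checks that a JSON round-trip preserves every field. It assumes Pydantic v2 (`model_copy`, `model_dump_json`, `model_validate_json`). The JSON round-trip is only a stand-in for however Flyte actually serializes the model, and the `Field` descriptions are omitted for brevity.

```python
from typing import Optional

from pydantic import BaseModel, Field


class Code(BaseModel):
    prefix: str = ""
    imports: str = ""
    code: str = ""


class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None


# Pass 1, generate(): record the model's answer and bump the iteration count.
state = AgentState()
answer = Code(prefix="Print a greeting.", imports="", code="print('hi')")
state.messages.append({"role": "assistant", "content": answer.code})
state = state.model_copy(update={"generation": answer, "iterations": state.iterations + 1})

# Pass 1, code_check(): the run failed, so feed the error back and set the flag.
state.messages.append(
    {"role": "user", "content": "Your solution failed the code execution test: NameError"}
)
state = state.model_copy(update={"error": "yes", "output": "NameError"})

# Stand-in for crossing a task boundary (assumption: JSON-style serialization).
restored = AgentState.model_validate_json(state.model_dump_json())
print(restored.iterations, restored.error, len(restored.messages))  # prints: 1 yes 2
```

Using `default_factory=list` gives every new `AgentState` its own empty message list. This matters because the tasks grow `messages` in place with `+=`.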
## Retrieve docs
We define a task that recursively loads the documentation pages under a given URL, sorts them by source URL in reverse order, and concatenates their text into a single string.
This string is then passed to the LLM as part of the prompt.
We set `max_depth=20` to limit how deep the loader follows links, which avoids loading an excessive number of documents.
Even with this limit, the resulting context can still be quite large.
To handle this, we use an LLM with an extended context window (GPT-4o, or GPT-4o mini when `debug` is enabled).
> [!NOTE]
> Appending all documents into a single string can result in extremely large contexts, potentially exceeding the LLM's token limit.
> If your dataset grows beyond what a single prompt can handle, there are a couple of strategies you can use.
> One option is to apply Retrieval-Augmented Generation (RAG), where you chunk the documents, embed them using a model,
> store the vectors in a vector database, and retrieve only the most relevant pieces at inference time.
>
> An alternative approach is to pass references to full files into the prompt, allowing the LLM to decide which files are most relevant based
> on natural-language search over file paths, summaries, or even contents. This method assumes that only a subset of files
> will be necessary for a given task, and the LLM is responsible for navigating the structure and identifying what to read.
> While this can be a lighter-weight solution for smaller datasets, its effectiveness depends on how well the LLM can
> reason over file references and the reliability of its internal search heuristics.
```
@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()

    # Sort the list based on the URLs and get the text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))
    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
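The RAG option described in the note earlier in this section starts by splitting the concatenated docs into overlapping chunks, then retrieves only the most relevant ones. The sketch below shows that data flow in plain Python. To stay self-contained, it replaces embeddings and a vector database with a simple word-overlap score. Treat it as an illustration of the shape of the pipeline, not as a retrieval method to ship. `chunk_text` and `top_k_chunks` are hypothetical helpers, and the chunk sizes are arbitrary.

```python
import re
from collections import Counter

# Same separator that docs_retriever places between pages.
SEPARATOR = "\n\n\n --- \n\n\n"


def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 200) -> list[str]:
    """Split text into fixed-size character windows; consecutive windows share `overlap` chars."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start : start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


def _words(text: str) -> Counter:
    """Lowercased word counts used for the toy relevance score."""
    return Counter(re.findall(r"[a-z0-9_]+", text.lower()))


def top_k_chunks(question: str, chunks: list[str], k: int = 3) -> list[str]:
    """Rank chunks by word overlap with the question (a stand-in for embedding similarity)."""
    q = _words(question)

    def score(chunk: str) -> int:
        c = _words(chunk)
        return sum(min(q[w], c[w]) for w in q)

    return sorted(chunks, key=score, reverse=True)[:k]


docs = SEPARATOR.join(
    [
        "Tasks can request memory with flyte.Resources.",
        "Secrets are injected as environment variables.",
    ]
)
chunks = chunk_text(docs, chunk_size=60, overlap=15)
print(len(chunks))
print(top_k_chunks("How are secrets injected?", chunks, k=1))
```

In a real pipeline, only the selected chunks would go into `{context}` in place of the full concatenated string. The overlap helps keep a sentence that straddles a chunk boundary intact in at least one chunk.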
## Code generation
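Next, we define a utility function that builds the LLM chain responsible for generating Python code from user input. The chain pipes a LangChain `ChatPromptTemplate`, which injects the retrieved documentation and the conversation history, into an OpenAI chat model wrapped with `with_structured_output(Code)`.
As a result, each response comes back as a `Code` object with separate prefix, imports, and code fields, ready to be assembled into a Flyte 2-compatible Python script.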
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "langchain-core==0.3.66",
# "langchain-openai==0.3.24",
# "langchain-community==0.3.26",
# "beautifulsoup4==4.13.4",
# "docker==7.1.0",
# ]
# main = "main"
# params = ""
# ///
# {{docs-fragment code_runner_task}}
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)
# {{/docs-fragment code_runner_task}}
# {{docs-fragment env}}
import tempfile
from typing import Optional

from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field

container_env = flyte.TaskEnvironment.from_task(
    "code-runner-container", code_runner_task
)

env = flyte.TaskEnvironment(
    name="code_runner",
    secrets=[flyte.Secret(key="openai_api_key", as_env_var="OPENAI_API_KEY")],
    image=flyte.Image.from_uv_script(__file__, name="code-runner-agent"),
    resources=flyte.Resources(cpu=1),
    depends_on=[container_env],
)
# {{/docs-fragment env}}
# {{docs-fragment code_base_model}}
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""

    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )
# {{/docs-fragment code_base_model}}
# {{docs-fragment agent_state}}
class AgentState(BaseModel):
    messages: list[dict[str, str]] = Field(default_factory=list)
    generation: Code = Field(default_factory=Code)
    iterations: int = 0
    error: str = "no"
    output: Optional[str] = None
# {{/docs-fragment agent_state}}
# {{docs-fragment generate_code_gen_chain}}
async def generate_code_gen_chain(debug: bool) -> Runnable:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    # Code-generation prompt
    code_gen_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """
You are a coding assistant with expertise in Python.
You are able to execute the Flyte v2 code locally in a sandbox environment.
Use the following pattern to execute the code:
if __name__ == "__main__":
    flyte.init_from_config()
    print(flyte.run(...))
Your response will be shown to the user.
Here is a full set of documentation:
-------
{context}
-------
Answer the user question based on the above provided documentation.
Ensure any code you provide can be executed with all required imports and variables defined.
Structure your answer with a description of the code solution.
Then list the imports. And finally list the functioning code block.
Here is the user question:""",
            ),
            ("placeholder", "{messages}"),
        ]
    )

    expt_llm = "gpt-4o" if not debug else "gpt-4o-mini"
    llm = ChatOpenAI(temperature=0, model=expt_llm)
    code_gen_chain = code_gen_prompt | llm.with_structured_output(Code)
    return code_gen_chain
# {{/docs-fragment generate_code_gen_chain}}


@env.task
async def docs_retriever(url: str) -> str:
    from bs4 import BeautifulSoup
    from langchain_community.document_loaders.recursive_url_loader import (
        RecursiveUrlLoader,
    )

    loader = RecursiveUrlLoader(
        url=url, max_depth=20, extractor=lambda x: BeautifulSoup(x, "html.parser").text
    )
    docs = loader.load()
    # Sort documents by source URL (descending) and join their text
    d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
    d_reversed = list(reversed(d_sorted))
    concatenated_content = "\n\n\n --- \n\n\n".join(
        [doc.page_content for doc in d_reversed]
    )
    return concatenated_content


@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution
    Args:
        question (str): The user question
        state (AgentState): The current graph state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode
    Returns:
        AgentState: Updated state with the new generation and an incremented iteration count
    """
    print("---GENERATING CODE SOLUTION---")
    messages = state.messages
    iterations = state.iterations
    error = state.error
    # On the first pass, seed the history with the user question so that
    # later retries (after code_check and reflect) still include it.
    if not messages:
        messages = [{"role": "user", "content": question}]
    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]
    code_gen_chain = await generate_code_gen_chain(debug)
    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]
    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )


@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code
    Args:
        state (AgentState): The current graph state
    Returns:
        AgentState: Updated state with error set to "yes" or "no" and the run output
    """
    print("---CHECKING CODE---")
    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations
    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()
    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name
    # Create temp file for the full script (imports + code body)
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name
    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )
    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")
    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )
    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")
    # No errors
    print("---NO CODE TEST FAILURES---")
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )


@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors
    Args:
        state (AgentState): The current graph state
        concatenated_content (str): Concatenated docs content
        debug (bool): Debug mode
    Returns:
        AgentState: Updated state with the reflection appended to messages
    """
    print("---REFLECTING---")
    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation
    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)
    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )


@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)
    state: AgentState = AgentState()
    iterations = 0
    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)
            error = state.error
            iterations = state.iterations
            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation
                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code
                code_output = state.output
                return f"""{prefix}
{imports}
{code}
Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
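The stopping rule in `main` can be seen in isolation in the minimal, Flyte-free sketch of the generate, check, reflect loop below. `LoopState`, `fake_generate`, `fake_check`, and `fake_reflect` are hypothetical stand-ins, not part of the tutorial. The fake checker simply "passes" once a chosen attempt number is reached, and no LLM or container is involved.

```python
from dataclasses import dataclass, field


@dataclass
class LoopState:
    # Hypothetical mirror of the AgentState fields that drive the loop.
    messages: list[str] = field(default_factory=list)
    iterations: int = 0
    error: str = "no"


def fake_generate(state: LoopState) -> LoopState:
    # Like `generate`, each generation increments the iteration counter.
    state.messages.append(f"attempt {state.iterations + 1}")
    state.iterations += 1
    return state


def fake_check(state: LoopState, succeed_on: int) -> LoopState:
    # Pretend the generated code only passes from attempt `succeed_on` onward.
    state.error = "no" if state.iterations >= succeed_on else "yes"
    return state


def fake_reflect(state: LoopState) -> LoopState:
    # Like `reflect`, this adds to the history but does not count as an iteration.
    state.messages.append("reflection")
    return state


def run_loop(max_iterations: int, succeed_on: int) -> LoopState:
    state = LoopState()
    while True:
        state = fake_generate(state)
        state = fake_check(state, succeed_on)
        # Same exit condition as `main`: success, or the iteration budget is spent.
        if state.error == "no" or state.iterations >= max_iterations:
            return state
        state = fake_reflect(state)
```

For example, `run_loop(3, 2)` stops after two attempts with `error == "no"`, while `run_loop(3, 10)` stops after three attempts with `error` still `"yes"`. The sketch mirrors a design choice in `main`: when the budget runs out, the loop returns the last attempt and its output instead of raising.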
We then define a `generate` task responsible for producing the code solution.
To improve clarity and testability, the LLM's output is structured as the `Code` model, which has three parts:
a `prefix` describing the problem and approach, an `imports` block containing only the import statements,
and a `code` block with the main body of executable code.
```
class Code(BaseModel):
    """Schema for code solutions to questions about Flyte v2."""
    prefix: str = Field(
        default="", description="Description of the problem and approach"
    )
    imports: str = Field(
        default="", description="Code block with just import statements"
    )
    code: str = Field(
        default="", description="Code block not including import statements"
    )


@env.task
async def generate(
    question: str, state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Generate a code solution
    Args:
        question (str): The user question
        state (AgentState): The current graph state
        concatenated_content (str): The concatenated docs content
        debug (bool): Debug mode
    Returns:
        AgentState: Updated state with the new generation and an incremented iteration count
    """
    print("---GENERATING CODE SOLUTION---")
    messages = state.messages
    iterations = state.iterations
    error = state.error
    # On the first pass, seed the history with the user question so that
    # later retries (after code_check and reflect) still include it.
    if not messages:
        messages = [{"role": "user", "content": question}]
    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            {
                "role": "user",
                "content": (
                    "Now, try again. Invoke the code tool to structure the output "
                    "with a prefix, imports, and code block:"
                ),
            }
        ]
    code_gen_chain = await generate_code_gen_chain(debug)
    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages += [
        {
            "role": "assistant",
            "content": f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        }
    ]
    return AgentState(
        messages=messages,
        generation=code_solution,
        iterations=iterations + 1,
        error=error,
        output=state.output,
    )
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
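The three-part structure is what lets `code_check` test the imports separately from the body. The sketch below shows how the parts become two scripts: one with only the imports, and one with the imports followed by the code. So that it runs with only the standard library, it assumes a plain dataclass stand-in (`CodeParts`) in place of the pydantic `Code` model. `build_scripts` and `write_temp_script` are hypothetical helpers, not part of the tutorial.

```python
import tempfile
from dataclasses import dataclass


@dataclass
class CodeParts:
    # Hypothetical stand-in for the pydantic `Code` model (same three fields).
    prefix: str = ""
    imports: str = ""
    code: str = ""


def build_scripts(solution: CodeParts) -> tuple[str, str]:
    """Return (imports_only_source, full_source), split the same way code_check does."""
    imports = solution.imports.strip()
    code = solution.code.strip()
    return imports + "\n", imports + "\n" + code + "\n"


def write_temp_script(source: str) -> str:
    # delete=False keeps the file on disk after the `with` block, as in code_check,
    # so the path can be handed to another process afterwards.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(source)
        return f.name
```

For example, `build_scripts(CodeParts(imports="import math", code="print(math.sqrt(16))"))` returns `("import math\n", "import math\nprint(math.sqrt(16))\n")`.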
A `ContainerTask` then executes this code in an isolated container environment.
It takes the script as a `File` input, runs it with `uv run --script`, and returns the program's combined output (stdout and stderr) and its exit code, both as strings.
```
import flyte
from flyte.extras import ContainerTask
from flyte.io import File

code_runner_task = ContainerTask(
    name="run_flyte_v2",
    image=flyte.Image.from_debian_base(),
    input_data_dir="/var/inputs",
    output_data_dir="/var/outputs",
    inputs={"script": File},
    outputs={"result": str, "exit_code": str},
    command=[
        "/bin/bash",
        "-c",
        (
            "set -o pipefail && "
            "uv run --script /var/inputs/script > /var/outputs/result 2>&1; "
            "echo $? > /var/outputs/exit_code"
        ),
    ],
    resources=flyte.Resources(cpu=1, memory="1Gi"),
)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
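To experiment with the same input/output contract locally, the sketch below approximates what the container command does. It runs a script, captures stdout and stderr together (the `2>&1` redirect), and reports the exit code as a string. It assumes the current Python interpreter (`sys.executable`) in place of `uv run --script`. A plain subprocess also provides none of the container's isolation. `run_script_like_container` is a hypothetical name.

```python
import os
import subprocess
import sys
import tempfile


def run_script_like_container(source: str, timeout: float = 60.0) -> tuple[str, str]:
    """Run `source` in a fresh Python process; return (combined_output, exit_code_str).

    Sketch only: a local subprocess with the current interpreter, not a container
    running `uv run --script`.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(source)
        path = f.name
    try:
        proc = subprocess.run(
            [sys.executable, path],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into stdout, like `2>&1`
            text=True,
            timeout=timeout,
        )
    finally:
        os.remove(path)  # clean up the temporary script
    # The container writes `$?` to a file, so the task reports the exit code as text.
    return proc.stdout, str(proc.returncode)
```

The real task writes the exit code with `echo $?`, so its `exit_code` output likely ends with a trailing newline. That is presumably why `code_check` compares `exit_code.strip()` against `"0"`.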
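The `code_check` task verifies that the generated code runs as expected.
It first runs a script containing only the import statements, then the full script (imports plus code), both through `code_runner_task`.
It records the output in the agent state and, on failure, sets `error` to `"yes"` and appends the failure message to the conversation so the next generation pass can use it.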
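The two-stage check can be sketched without Flyte as follows. This is a minimal local approximation that uses `python -c` via `sys.executable` as a stand-in for `code_runner_task`. `run_source` and `two_stage_check` are hypothetical helpers. Like `code_check`, the sketch stops at the first failing stage and appends a feedback message for the next generation pass.

```python
import subprocess
import sys


def run_source(source: str) -> tuple[str, str]:
    # Hypothetical local runner standing in for `code_runner_task` (no container, no uv).
    proc = subprocess.run(
        [sys.executable, "-c", source],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # combined output, like `2>&1`
        text=True,
    )
    return proc.stdout, str(proc.returncode) + "\n"  # mimic `echo $?` trailing newline


def two_stage_check(
    imports: str, code: str, messages: list[dict[str, str]]
) -> tuple[str, str]:
    """Return (error, output); on failure append a user message, like code_check."""
    imports = imports.strip()
    code = code.strip()
    # Stage 1: imports only, so missing packages are reported separately.
    output, exit_code = run_source(imports + "\n")
    if exit_code.strip() != "0":
        messages.append(
            {"role": "user", "content": f"Your solution failed the import test: {output}"}
        )
        return "yes", output
    # Stage 2: imports plus the code body.
    output, exit_code = run_source(imports + "\n" + code + "\n")
    if exit_code.strip() != "0":
        messages.append(
            {"role": "user", "content": f"Your solution failed the code execution test: {output}"}
        )
        return "yes", output
    return "no", output
```

Checking imports first separates missing-dependency errors from logic errors, so the feedback tells the model which part of its answer to fix.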
```
@env.task
async def code_check(state: AgentState) -> AgentState:
    """
    Check code
    Args:
        state (AgentState): The current graph state
    Returns:
        AgentState: Updated state with error set to "yes" or "no" and the run output
    """
    print("---CHECKING CODE---")
    # State
    messages = state.messages
    code_solution = state.generation
    iterations = state.iterations
    # Get solution components
    imports = code_solution.imports.strip()
    code = code_solution.code.strip()
    # Create temp file for imports
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False
    ) as imports_file:
        imports_file.write(imports + "\n")
        imports_path = imports_file.name
    # Create temp file for the full script (imports + code body)
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as code_file:
        code_file.write(imports + "\n" + code + "\n")
        code_path = code_file.name
    # Check imports
    import_output, import_exit_code = await code_runner_task(
        script=await File.from_local(imports_path)
    )
    if import_exit_code.strip() != "0":
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the import test: {import_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=import_output,
        )
    else:
        print("---CODE IMPORT CHECK: PASSED---")
    # Check execution
    code_output, code_exit_code = await code_runner_task(
        script=await File.from_local(code_path)
    )
    if code_exit_code.strip() != "0":
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [
            {
                "role": "user",
                "content": f"Your solution failed the code execution test: {code_output}",
            }
        ]
        messages += error_message
        return AgentState(
            generation=code_solution,
            messages=messages,
            iterations=iterations,
            error="yes",
            output=code_output,
        )
    else:
        print("---CODE BLOCK CHECK: PASSED---")
    # No errors
    print("---NO CODE TEST FAILURES---")
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error="no",
        output=code_output,
    )
# {{/docs-fragment code_check}}
# {{docs-fragment reflect}}
@env.task
async def reflect(
    state: AgentState, concatenated_content: str, debug: bool
) -> AgentState:
    """
    Reflect on errors
    Args:
        state (AgentState): The current agent state
        concatenated_content (str): Concatenated docs content
        debug (bool): If True, use the smaller gpt-4o-mini model
    Returns:
        AgentState: Updated state with the reflection appended to `messages`
    """
    print("---REFLECTING---")
    # State
    messages = state.messages
    iterations = state.iterations
    code_solution = state.generation
    # Prompt reflection
    code_gen_chain = await generate_code_gen_chain(debug)
    # Add reflection
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages += [
        {
            "role": "assistant",
            "content": f"Here are reflections on the error: {reflections}",
        }
    ]
    return AgentState(
        generation=code_solution,
        messages=messages,
        iterations=iterations,
        error=state.error,
        output=state.output,
    )
# {{/docs-fragment reflect}}
# {{docs-fragment main}}
@env.task
async def main(
    question: str = (
        "Define a two-task pattern where the second catches OOM from the first and retries with more memory."
    ),
    url: str = "https://pre-release-v2.docs-builder.pages.dev/docs/byoc/user-guide/",
    max_iterations: int = 3,
    debug: bool = False,
) -> str:
    concatenated_content = await docs_retriever(url=url)
    state: AgentState = AgentState()
    iterations = 0
    while True:
        with flyte.group(f"code-generation-pass-{iterations + 1}"):
            state = await generate(question, state, concatenated_content, debug)
            state = await code_check(state)
            error = state.error
            iterations = state.iterations
            if error == "no" or iterations >= max_iterations:
                print("---DECISION: FINISH---")
                code_solution = state.generation
                prefix = code_solution.prefix
                imports = code_solution.imports
                code = code_solution.code
                code_output = state.output
                return f"""{prefix}
{imports}
{code}
Result of code execution:
{code_output}
"""
            else:
                print("---DECISION: RE-TRY SOLUTION---")
                state = await reflect(state, concatenated_content, debug)


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/code_runner/agent.py)
If either check fails, a separate `reflect` task sends the message history, which now includes the error output, back to the LLM.
Its response is appended to the agent state's messages as a reflection that guides the next generation attempt.
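The message-history bookkeeping behind `code_check` and `reflect` can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the tutorial's code: `record_failure` and `record_reflection` are hypothetical helpers that reuse the message wording from the tasks above, and the reflection string is a hard-coded stand-in for the LLM response that `reflect` obtains through `generate_code_gen_chain`.

```python
# Minimal sketch of how a failed check and a reflection extend the
# conversation history. record_failure and record_reflection are
# hypothetical helpers; no LLM is called here.

Message = dict[str, str]


def record_failure(messages: list[Message], stage: str, output: str) -> list[Message]:
    """Append the runner output as a user message, like code_check does on failure."""
    return messages + [
        {"role": "user", "content": f"Your solution failed the {stage} test: {output}"}
    ]


def record_reflection(messages: list[Message], reflection: str) -> list[Message]:
    """Append the model's reflection as an assistant message, like reflect does."""
    return messages + [
        {"role": "assistant", "content": f"Here are reflections on the error: {reflection}"}
    ]


history: list[Message] = [{"role": "user", "content": "Write a Flyte task."}]
history = record_failure(history, "import", "ModuleNotFoundError: No module named 'flyt'")
# Assumption: a fixed string stands in for the LLM's reflection.
history = record_reflection(history, "The import misspelled 'flyte'.")

for message in history:
    print(message["role"], "->", message["content"])
```

Unlike the tasks, which extend `state.messages` in place with `+=`, these helpers return new lists, which keeps the sketch free of shared mutable state.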
Finally, the `main` task orchestrates the steps above. It retrieves the documentation once, then generates and checks a solution, grouping each pass with `flyte.group`.
If the check fails, it reflects on the error and retries until the check passes or `max_iterations` is reached.
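The control flow of `main` reduces to a bounded generate, check, reflect loop. The sketch below keeps that shape but swaps the Flyte tasks for hypothetical synchronous stubs (`fake_generate`, `fake_check`, `fake_reflect`) so it runs without an LLM or a cluster; the stubs are assumptions, rigged so that the first attempt fails and the second passes.

```python
# Minimal sketch of main's bounded retry loop, with hypothetical stubs
# standing in for the generate, code_check, and reflect tasks.
from dataclasses import dataclass, field


@dataclass
class State:
    messages: list[str] = field(default_factory=list)
    iterations: int = 0
    error: str = "no"


def fake_generate(state: State) -> State:
    # Each generation attempt increments the iteration counter, as generate does.
    state.iterations += 1
    state.messages.append(f"attempt {state.iterations}")
    return state


def fake_check(state: State) -> State:
    # Assumption for the demo: the first attempt fails, later attempts pass.
    state.error = "yes" if state.iterations < 2 else "no"
    return state


def fake_reflect(state: State) -> State:
    state.messages.append("reflection on the error")
    return state


def run_agent(max_iterations: int = 3) -> State:
    state = State()
    while True:
        state = fake_check(fake_generate(state))
        # Stop on success or once the iteration budget is spent.
        if state.error == "no" or state.iterations >= max_iterations:
            return state
        state = fake_reflect(state)


final = run_agent()
print(final.iterations, final.error, final.messages)
```

Calling `run_agent(max_iterations=1)` stops after the first failed check, which mirrors how `main` returns the last attempt and its error output once `max_iterations` is reached.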
## Running the code agent
Run the code agent on a Flyte/Union cluster with the following command:
```
uv run --prerelease=allow agent.py
```
If things are working properly, you should see output similar to the following:
```
---GENERATING CODE SOLUTION---
---CHECKING CODE---
---CODE BLOCK CHECK: PASSED---
---NO CODE TEST FAILURES---
---DECISION: FINISH---
In this solution, we define two tasks using Flyte v2.
The first task, `oomer`, is designed to simulate an out-of-memory (OOM) error by attempting to allocate a large list.
The second task, `failure_recovery`, attempts to execute `oomer` and catches any OOM errors.
If an OOM error is caught, it retries the `oomer` task with increased memory resources.
This pattern demonstrates how to handle resource-related exceptions and dynamically adjust task configurations in Flyte workflows.
import asyncio
import flyte
import flyte.errors
env = flyte.TaskEnvironment(name="oom_example", resources=flyte.Resources(cpu=1, memory="250Mi"))
@env.task
async def oomer(x: int):
    large_list = [0] * 100000000 # Simulate OOM
    print(len(large_list))
@env.task
async def always_succeeds() -> int:
    await asyncio.sleep(1)
    return 42
...
```
=== PAGE: https://www.union.ai/docs/v2/byoc/tutorials/text_to_sql ===
# Text-to-SQL
> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/text_to_sql); based on work by [LlamaIndex](https://docs.llamaindex.ai/en/stable/examples/workflow/advanced_text_to_sql/).
Data analytics drives modern decision-making, but SQL often creates a bottleneck. Writing queries requires technical expertise, so non-technical stakeholders must rely on data teams. That translation layer slows everyone down.
Text-to-SQL narrows this gap by turning natural language into executable SQL queries. It lowers the barrier to structured data and makes databases accessible to more people.
In this tutorial, we build a Text-to-SQL workflow using LlamaIndex and evaluate it on the [WikiTableQuestions dataset](https://ppasupat.github.io/WikiTableQuestions/) (a benchmark of natural language questions over semi-structured tables). We then explore prompt optimization to see whether it improves accuracy and show how to track prompts and results over time. Along the way, we'll see what worked, what didn't, and what we learned about building durable evaluation pipelines. The pattern here can be adapted to your own datasets and workflows.

## Ingesting data
We start by ingesting the WikiTableQuestions dataset, which comes as CSV files, into a SQLite database. This database serves as the source of truth for our Text-to-SQL pipeline.
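Loading one CSV file into a SQLite table with sanitized column names can be sketched with only the Python standard library. This is a simplified sketch, not the tutorial's implementation: the tutorial uses pandas and SQLAlchemy, maps column types from DataFrame dtypes, and names each table with an LLM, whereas here every column is stored as `TEXT`, and the table name `players` and the sample data are hypothetical.

```python
# Simplified, standard-library-only sketch of CSV -> SQLite ingestion.
import csv
import io
import re
import sqlite3


def sanitize_column_name(col_name: str) -> str:
    """Replace runs of non-word characters with underscores (same regex as the tutorial)."""
    return re.sub(r"\W+", "_", col_name)


def load_csv_into_table(conn: sqlite3.Connection, csv_text: str, table_name: str) -> int:
    """Create table_name from the CSV header, insert every row, and return the row count."""
    reader = csv.reader(io.StringIO(csv_text))
    header = [sanitize_column_name(col) for col in next(reader)]
    # Table and column names are interpolated into SQL, so they must be trusted.
    columns = ", ".join(f'"{col}" TEXT' for col in header)
    conn.execute(f'CREATE TABLE "{table_name}" ({columns})')
    placeholders = ", ".join("?" for _ in header)
    rows = list(reader)
    conn.executemany(f'INSERT INTO "{table_name}" VALUES ({placeholders})', rows)
    conn.commit()
    return len(rows)


conn = sqlite3.connect(":memory:")
sample = "Player Name,Goals (total)\nAlice,3\nBob,5\n"
count = load_csv_into_table(conn, sample, "players")
print(count, conn.execute('SELECT "Goals_total_" FROM players').fetchall())
```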
```
import asyncio
import fnmatch
import os
import re
import zipfile
import flyte
import pandas as pd
import requests
from flyte.io import Dir, File
from llama_index.core.llms import ChatMessage
from llama_index.core.prompts import ChatPromptTemplate
from llama_index.llms.openai import OpenAI
from pydantic import BaseModel, Field
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from utils import env
# {{docs-fragment table_info}}
class TableInfo(BaseModel):
    """Information regarding a structured table."""
    table_name: str = Field(..., description="table name (underscores only, no spaces)")
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )
# {{/docs-fragment table_info}}
@env.task
async def download_and_extract(zip_path: str, search_glob: str) -> Dir:
    """Download and extract the dataset zip file if not already available."""
    output_zip = "data.zip"
    extract_dir = "wiki_table_questions"
    if not os.path.exists(zip_path):
        response = requests.get(zip_path, stream=True)
        with open(output_zip, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    else:
        output_zip = zip_path
        print(f"Using existing file {output_zip}")
    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(output_zip, "r") as zip_ref:
        for member in zip_ref.namelist():
            if fnmatch.fnmatch(member, search_glob):
                zip_ref.extract(member, extract_dir)
    remote_dir = await Dir.from_local(extract_dir)
    return remote_dir
async def read_csv_file(
    csv_file: File, nrows: int | None = None
) -> pd.DataFrame | None:
    """Safely download and parse a CSV file into a DataFrame."""
    try:
        local_csv_file = await csv_file.download()
        return pd.read_csv(local_csv_file, nrows=nrows)
    except Exception as e:
        print(f"Error parsing {csv_file.path}: {e}")
        return None
def sanitize_column_name(col_name: str) -> str:
    """Sanitize column names by replacing spaces/special chars with underscores."""
    return re.sub(r"\W+", "_", col_name)
async def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    """Create a SQL table from a Pandas DataFrame."""
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)
    # Define table columns based on DataFrame dtypes
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]
    table = Table(table_name, metadata_obj, *columns)
    # Create table in database
    metadata_obj.create_all(engine)
    # Insert data into table
    with engine.begin() as conn:
        for _, row in df.iterrows():
            conn.execute(table.insert().values(**row.to_dict()))
@flyte.trace
async def create_table(
    csv_file: File, table_info: TableInfo, database_path: str
) -> str:
    """Safely create a table from CSV if parsing succeeds."""
    df = await read_csv_file(csv_file)
    if df is None:
        return "false"
    print(f"Creating table: {table_info.table_name}")
    engine = create_engine(f"sqlite:///{database_path}")
    metadata_obj = MetaData()
    await create_table_from_dataframe(df, table_info.table_name, engine, metadata_obj)
    return "true"
@flyte.trace
async def llm_structured_predict(
    df_str: str,
    table_names: list[str],
    prompt_tmpl: ChatPromptTemplate,
    feedback: str,
    llm: OpenAI,
) -> TableInfo:
    return llm.structured_predict(
        TableInfo,
        prompt_tmpl,
        feedback=feedback,
        table_str=df_str,
        exclude_table_name_list=str(list(table_names)),
    )
async def generate_unique_table_info(
    df_str: str,
    table_names: list[str],
    prompt_tmpl: ChatPromptTemplate,
    llm: OpenAI,
    tablename_lock: asyncio.Lock,
    retries: int = 3,
) -> TableInfo | None:
    """Ask the LLM for a TableInfo, retrying with feedback until the table name is unique."""
    last_table_name = None
    for attempt in range(retries):
        feedback = ""
        if attempt > 0:
            feedback = f"Note: '{last_table_name}' already exists. Please pick a new name not in {table_names}."
        table_info = await llm_structured_predict(
            df_str, table_names, prompt_tmpl, feedback, llm
        )
        last_table_name = table_info.table_name
        async with tablename_lock:
            if table_info.table_name not in table_names:
                table_names.append(table_info.table_name)
                return table_info
        print(f"Table name {table_info.table_name} already exists, retrying...")
    return None
async def process_csv_file(
    csv_file: File,
    table_names: list[str],
    semaphore: asyncio.Semaphore,
    tablename_lock: asyncio.Lock,
    llm: OpenAI,
    prompt_tmpl: ChatPromptTemplate,
) -> TableInfo | None:
    """Process a single CSV file to generate a unique TableInfo."""
    async with semaphore:
        df = await read_csv_file(csv_file, nrows=10)
        if df is None:
            return None
        return await generate_unique_table_info(
            df.to_csv(), table_names, prompt_tmpl, llm, tablename_lock
        )
@env.task
async def extract_table_info(
    data_dir: Dir, model: str, concurrency: int
) -> list[TableInfo | None]:
    """Extract structured table information from CSV files."""
    table_names: list[str] = []
    semaphore = asyncio.Semaphore(concurrency)
    tablename_lock = asyncio.Lock()
    llm = OpenAI(model=model)
    prompt_str = """\
Provide a JSON object with the following fields:
- `table_name`: must be unique and descriptive (underscores only, no generic names).
- `table_summary`: short and concise summary of the table.
Do NOT use any of these table names: {exclude_table_name_list}
Table:
{table_str}
{feedback}
"""
    prompt_tmpl = ChatPromptTemplate(
        message_templates=[ChatMessage.from_str(prompt_str, role="user")]
    )
    tasks = [
        process_csv_file(
            csv_file, table_names, semaphore, tablename_lock, llm, prompt_tmpl
        )
        async for csv_file in data_dir.walk()
    ]
    return await asyncio.gather(*tasks)
# {{docs-fragment data_ingestion}}
@env.task
async def data_ingestion(
    csv_zip_path: str = "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
    search_glob: str = "WikiTableQuestions/csv/200-csv/*.csv",
    concurrency: int = 5,
    model: str = "gpt-4o-mini",
) -> tuple[File, list[TableInfo | None]]:
    """Main data ingestion pipeline: download → extract → analyze → create DB."""
    data_dir = await download_and_extract(csv_zip_path, search_glob)
    table_infos = await extract_table_info(data_dir, model, concurrency)
    database_path = "wiki_table_questions.db"
    i = 0
    async for csv_file in data_dir.walk():
        table_info = table_infos[i]
        if table_info:
            ok = await create_table(csv_file, table_info, database_path)
            if ok == "false":
                table_infos[i] = None
        else:
            print(f"Skipping table creation for {csv_file} due to missing TableInfo.")
        i += 1
    db_file = await File.from_local(database_path)
    return db_file, table_infos
# {{/docs-fragment data_ingestion}}
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/data_ingestion.py)
The ingestion step:
1. Downloads the dataset (a zip archive from GitHub).
2. Extracts the CSV files locally.
3. Generates table metadata (names and descriptions).
4. Creates corresponding tables in SQLite.
The Flyte task returns both the SQLite database (as a `File`) and the generated table metadata, with `None` for any CSV file that could not be processed.
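Step 3 must keep table names unique while several CSV files are summarized concurrently. The tutorial's `generate_unique_table_info` handles this by checking and reserving each name while holding an `asyncio.Lock`, and `process_csv_file` caps concurrency with an `asyncio.Semaphore`. The following is a minimal sketch of the same pattern, assuming a hypothetical `propose_name` coroutine in place of the LLM call: instead of sending feedback to a model, it derives `base`, then `base_1`, `base_2`, and so on.

```python
# Sketch: concurrent workers reserving unique names under an asyncio.Lock.
# propose_name is a hypothetical stand-in for the LLM call.
import asyncio
from typing import Optional


async def propose_name(base: str, attempt: int) -> str:
    await asyncio.sleep(0)  # yield to the event loop, as a real API call would
    return base if attempt == 0 else f"{base}_{attempt}"


async def reserve_unique_name(
    base: str, taken: list[str], lock: asyncio.Lock, retries: int = 3
) -> Optional[str]:
    for attempt in range(retries):
        candidate = await propose_name(base, attempt)
        # Check and reserve while holding the lock, mirroring the tutorial.
        # With no await inside this block the check is already atomic on one
        # event loop; the lock keeps it safe if an await is ever added here.
        async with lock:
            if candidate not in taken:
                taken.append(candidate)
                return candidate
    return None  # every proposed name was already taken


async def name_tables(bases: list[str], concurrency: int = 2) -> list[Optional[str]]:
    taken: list[str] = []
    lock = asyncio.Lock()
    semaphore = asyncio.Semaphore(concurrency)  # cap how many workers run at once

    async def worker(base: str) -> Optional[str]:
        async with semaphore:
            return await reserve_unique_name(base, taken, lock)

    return await asyncio.gather(*(worker(base) for base in bases))


print(asyncio.run(name_tables(["sales", "sales", "inventory"])))
```

A worker that exhausts its retries returns `None`, just as `generate_unique_table_info` does, so callers must handle missing metadata.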
```
from pydantic import BaseModel, Field


class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(..., description="table name (underscores only, no spaces)")
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/data_ingestion.py)
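`extract_table_info` and `generate_unique_table_info` combine two asyncio primitives. An `asyncio.Semaphore` caps how many CSV files are analyzed at once, and an `asyncio.Lock` guards the check-then-append on the shared `table_names` list. The sketch below reproduces that pattern without Flyte or an LLM. `propose_name`, `unique_name`, `process`, and `name_all` are hypothetical; in particular, `propose_name` stands in for `llm_structured_predict` and deliberately returns colliding names. The tutorial re-prompts the model with feedback on a collision, whereas this sketch simply appends a numeric suffix.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Stats:
    active: int = 0  # coroutines currently inside the semaphore
    peak: int = 0  # highest value `active` ever reached


async def propose_name(base: str, attempt: int) -> str:
    # Hypothetical stand-in for the LLM call; later attempts append a suffix
    await asyncio.sleep(0)  # yield to the event loop like a network call would
    return base if attempt == 0 else f"{base}_{attempt}"


async def unique_name(
    base: str, table_names: list[str], lock: asyncio.Lock, retries: int = 3
) -> Optional[str]:
    for attempt in range(retries):
        name = await propose_name(base, attempt)
        async with lock:
            # Check and append under the lock so the two steps stay together
            if name not in table_names:
                table_names.append(name)
                return name
    return None  # every attempt collided with an existing name


async def process(base, table_names, semaphore, lock, stats):
    async with semaphore:  # at most `concurrency` coroutines run this block at once
        stats.active += 1
        stats.peak = max(stats.peak, stats.active)
        try:
            return await unique_name(base, table_names, lock)
        finally:
            stats.active -= 1


async def name_all(bases: list[str], concurrency: int = 2):
    table_names: list[str] = []
    semaphore = asyncio.Semaphore(concurrency)
    lock = asyncio.Lock()
    stats = Stats()
    results = await asyncio.gather(
        *(process(b, table_names, semaphore, lock, stats) for b in bases)
    )
    return results, stats.peak


results, peak = asyncio.run(name_all(["players"] * 4 + ["teams"], concurrency=2))
print(results, peak)
```

With four requests for `players` and three attempts each, only three distinct names are available, so exactly one request ends up with `None`, which mirrors how a `TableInfo` entry can come back as `None`. Note that in a single-threaded event loop the guarded block contains no `await`, so the lock acts as a safeguard. It becomes essential if an `await` is ever added between the membership check and the append.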
With Union artifacts (coming soon!), you'll be able to persist the ingested SQLite database as an artifact, so you won't need to rerun data ingestion in every pipeline.
## From question to SQL
Next, we define a workflow that converts natural language into executable SQL using a retrieval-augmented generation (RAG) approach.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte==2.0.0b31",
#     "llama-index-core>=0.11.0",
#     "llama-index-llms-openai>=0.2.0",
#     "sqlalchemy>=2.0.0",
#     "pandas>=2.0.0",
#     "requests>=2.25.0",
#     "pydantic>=2.0.0",
# ]
# main = "text_to_sql"
# params = ""
# ///
import asyncio
from pathlib import Path

import flyte
from data_ingestion import TableInfo, data_ingestion
from flyte.io import Dir, File
from llama_index.core import (
    PromptTemplate,
    SQLDatabase,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI
from sqlalchemy import create_engine, text
from utils import env


@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
    """Index a single table into vector store."""
    path = f"{table_index_dir}/{table_name}"
    engine = create_engine(database_uri)

    def _fetch_rows():
        with engine.connect() as conn:
            cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
            return cursor.fetchall()

    result = await asyncio.to_thread(_fetch_rows)
    nodes = [TextNode(text=str(tuple(row))) for row in result]
    index = VectorStoreIndex(nodes)
    index.set_index_id("vector_index")
    index.storage_context.persist(path)
    return path


@env.task
async def index_all_tables(db_file: File) -> Dir:
    """Index all tables concurrently."""
    table_index_dir = "table_indices"
    Path(table_index_dir).mkdir(exist_ok=True)
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    tasks = [
        index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
        for t in sql_database.get_usable_table_names()
    ]
    await asyncio.gather(*tasks)
    remote_dir = await Dir.from_local(table_index_dir)
    return remote_dir


@flyte.trace
async def get_table_schema_context(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
) -> str:
    """Retrieve schema + optional description context for a single table."""
    engine = create_engine(database_uri)
    sql_database = SQLDatabase(engine)
    table_info = sql_database.get_single_table_info(table_schema_obj.table_name)
    if table_schema_obj.context_str:
        table_info += f" The table description is: {table_schema_obj.context_str}"
    return table_info


@flyte.trace
async def get_table_row_context(
    table_schema_obj: SQLTableSchema,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Retrieve row-level context examples using vector search."""
    storage_context = StorageContext.from_defaults(
        persist_dir=str(f"{local_vector_index_dir}/{table_schema_obj.table_name}")
    )
    vector_index = load_index_from_storage(storage_context, index_id="vector_index")
    vector_retriever = vector_index.as_retriever(similarity_top_k=2)
    relevant_nodes = vector_retriever.retrieve(query)
    if not relevant_nodes:
        return ""
    row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
    for node in relevant_nodes:
        row_context += str(node.get_content()) + "\n"
    return row_context


async def process_table(
    table_schema_obj: SQLTableSchema,
    database_uri: str,
    local_vector_index_dir: str,
    query: str,
) -> str:
    """Combine schema + row context for one table."""
    table_info = await get_table_schema_context(table_schema_obj, database_uri)
    row_context = await get_table_row_context(
        table_schema_obj, local_vector_index_dir, query
    )
    full_context = table_info
    if row_context:
        full_context += "\n" + row_context
    print(f"Table Info: {full_context}")
    return full_context


async def get_table_context_and_rows_str(
    query: str,
    database_uri: str,
    table_schema_objs: list[SQLTableSchema],
    vector_index_dir: Dir,
):
    """Get combined schema + row context for all tables."""
    local_vector_index_dir = await vector_index_dir.download()
    # run per-table work concurrently
    context_strs = await asyncio.gather(
        *[
            process_table(t, database_uri, local_vector_index_dir, query)
            for t in table_schema_objs
        ]
    )
    return "\n\n".join(context_strs)


@env.task
async def retrieve_tables(
    query: str,
    table_infos: list[TableInfo | None],
    db_file: File,
    vector_index_dir: Dir,
) -> str:
    """Retrieve relevant tables and return schema context string."""
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    table_node_mapping = SQLTableNodeMapping(sql_database)
    table_schema_objs = [
        SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
        for t in table_infos
        if t is not None
    ]
    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
    )
    obj_retriever = obj_index.as_retriever(similarity_top_k=3)
    retrieved_schemas = obj_retriever.retrieve(query)
    return await get_table_context_and_rows_str(
        query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
    )


def parse_response_to_sql(chat_response: ChatResponse) -> str:
    """Extract SQL query from LLM response."""
    response = chat_response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()


@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
    """Generate SQL query from natural language question and table context."""
    llm = OpenAI(model=model)
    fmt_messages = (
        PromptTemplate(
            prompt,
            prompt_type=PromptType.TEXT_TO_SQL,
        )
        .partial_format(dialect="sqlite")
        .format_messages(query_str=query, schema=table_context)
    )
    chat_response = await llm.achat(fmt_messages)
    return parse_response_to_sql(chat_response)


@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
    """Run SQL query on database and synthesize final response."""
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)
    retrieved_rows = sql_retriever.retrieve(sql)
    response_synthesis_prompt = PromptTemplate(
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )
    llm = OpenAI(model=model)
    fmt_messages = response_synthesis_prompt.format_messages(
        sql_query=sql,
        context_str=str(retrieved_rows),
        query_str=query,
    )
    chat_response = await llm.achat(fmt_messages)
    return chat_response.message.content


@env.task
async def text_to_sql(
    system_prompt: str = (
        "Given an input question, first create a syntactically correct {dialect} "
        "query to run, then look at the results of the query and return the answer. "
        "You can order the results by a relevant column to return the most "
        "interesting examples in the database.\n\n"
        "Never query for all the columns from a specific table, only ask for a "
        "few relevant columns given the question.\n\n"
        "Pay attention to use only the column names that you can see in the schema "
        "description. "
        "Be careful to not query for columns that do not exist. "
        "Pay attention to which column is in which table. "
        "Also, qualify column names with the table name when needed. "
        "You are required to use the following format, each taking one line:\n\n"
        "Question: Question here\n"
        "SQLQuery: SQL Query to run\n"
        "SQLResult: Result of the SQLQuery\n"
        "Answer: Final answer here\n\n"
        "Only use tables listed below.\n"
        "{schema}\n\n"
        "Question: {query_str}\n"
        "SQLQuery: "
    ),
    query: str = "What was the year that The Notorious BIG was signed to Bad Boy?",
    model: str = "gpt-4o-mini",
) -> str:
    db_file, table_infos = await data_ingestion()
    vector_index_dir = await index_all_tables(db_file)
    table_context = await retrieve_tables(query, table_infos, db_file, vector_index_dir)
    sql = await generate_sql(query, table_context, model, system_prompt)
    return await generate_response(query, sql, db_file, model)


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(text_to_sql)
    print(run.url)
    run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py)
The main `text_to_sql` task orchestrates the pipeline:
- Ingest data
- Build vector indices for each table
- Retrieve relevant tables and rows
- Generate SQL queries with an LLM
- Execute queries and synthesize answers
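We use OpenAI GPT models (`gpt-4o-mini` by default) with a structured system prompt that restricts the model to the tables and columns it is shown and requires a fixed `Question` / `SQLQuery` / `SQLResult` / `Answer` layout, so the SQL can be extracted from the reply with simple string matching.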
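Because the reply follows a fixed layout, recovering the SQL is plain string slicing, which is what `parse_response_to_sql` does on a LlamaIndex `ChatResponse`. The standalone sketch below applies the same logic to a plain string. `extract_sql` is a hypothetical name, and its final step is a hypothetical hardening that is not part of the tutorial: stripping backtick characters from a fenced reply still leaves a leading `sql` language tag, so this version removes that tag as well.

```python
def extract_sql(response: str) -> str:
    """Return the SQL between 'SQLQuery:' and 'SQLResult:' in an LLM reply."""
    start = response.find("SQLQuery:")
    if start != -1:
        response = response[start + len("SQLQuery:"):]
    end = response.find("SQLResult:")
    if end != -1:
        response = response[:end]
    sql = response.strip().strip("`").strip()
    # Hypothetical hardening (not in the tutorial): drop a leading "sql"
    # language tag that remains after stripping a Markdown code fence.
    if sql.lower().startswith("sql\n"):
        sql = sql[len("sql\n"):].strip()
    return sql


reply = (
    "Question: When was the label founded?\n"
    "SQLQuery: SELECT year FROM labels WHERE name = 'Example'\n"
    "SQLResult: 1993\n"
    "Answer: 1993"
)
print(extract_sql(reply))  # SELECT year FROM labels WHERE name = 'Example'
```

If the model ignores the layout entirely and returns bare SQL, neither marker is found and the text is returned unchanged apart from whitespace and backticks.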
### Vector indexing
We index each table's rows semantically so the model can retrieve relevant examples during SQL generation.
```
import asyncio
from pathlib import Path

import flyte
from flyte.io import Dir, File
from llama_index.core import SQLDatabase, VectorStoreIndex
from llama_index.core.schema import TextNode
from sqlalchemy import create_engine, text
from utils import env


@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
    """Index a single table into vector store."""
    path = f"{table_index_dir}/{table_name}"
    engine = create_engine(database_uri)

    def _fetch_rows():
        with engine.connect() as conn:
            cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
            return cursor.fetchall()

    result = await asyncio.to_thread(_fetch_rows)
    nodes = [TextNode(text=str(tuple(row))) for row in result]
    index = VectorStoreIndex(nodes)
    index.set_index_id("vector_index")
    index.storage_context.persist(path)
    return path


@env.task
async def index_all_tables(db_file: File) -> Dir:
    """Index all tables concurrently."""
    table_index_dir = "table_indices"
    Path(table_index_dir).mkdir(exist_ok=True)
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    tasks = [
        index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
        for t in sql_database.get_usable_table_names()
    ]
    await asyncio.gather(*tasks)
    remote_dir = await Dir.from_local(table_index_dir)
    return remote_dir
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py)
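Each row is stored as a `TextNode` (the string form of the row tuple) in a LlamaIndex `VectorStoreIndex`, and each table's index is persisted to its own directory under `table_indices/`. At query time, this lets the system pull the rows that are semantically closest to the question.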
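To make the store-then-rank flow concrete without calling an embedding API, the sketch below replaces the embedding model with a toy bag-of-words vector and ranks stringified rows by cosine similarity, keeping the top `k` (the tutorial uses `similarity_top_k=2`). `embed`, `cosine`, and `top_k_rows` are hypothetical helpers, and word overlap is only a stand-in: real embeddings capture meaning, not just shared tokens.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    # Toy "embedding": lowercase word counts (stand-in for a real embedding model)
    return Counter(re.findall(r"\w+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[w] * b[w] for w in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(
        sum(v * v for v in b.values())
    )
    return dot / norm if norm else 0.0


def top_k_rows(rows: list[tuple], query: str, k: int = 2) -> list[str]:
    # Each row becomes one text "node", mirroring TextNode(text=str(tuple(row)))
    nodes = [str(tuple(row)) for row in rows]
    query_vec = embed(query)
    ranked = sorted(nodes, key=lambda node: cosine(embed(node), query_vec), reverse=True)
    return ranked[:k]


rows = [
    (1994, "Alpha Records", "signed"),
    (1995, "Beta Music", "released"),
    (1996, "Gamma Sound", "toured"),
]
print(top_k_rows(rows, "When was Alpha Records signed?", k=2))
```

A vector store does the same job at scale: it embeds every node once at indexing time, so only the query needs to be embedded when a question arrives.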
### Table retrieval and context building
We then retrieve the most relevant tables for a given query and build rich context that combines schema information with sample rows.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# "sqlalchemy>=2.0.0",
# "pandas>=2.0.0",
# "requests>=2.25.0",
# "pydantic>=2.0.0",
# ]
# main = "text_to_sql"
# params = ""
# ///
import asyncio
from pathlib import Path
import flyte
from data_ingestion import TableInfo, data_ingestion
from flyte.io import Dir, File
from llama_index.core import (
PromptTemplate,
SQLDatabase,
StorageContext,
VectorStoreIndex,
load_index_from_storage,
)
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import TextNode
from llama_index.llms.openai import OpenAI
from sqlalchemy import create_engine, text
from utils import env
# {{docs-fragment index_tables}}
@flyte.trace
async def index_table(table_name: str, table_index_dir: str, database_uri: str) -> str:
"""Index a single table into vector store."""
path = f"{table_index_dir}/{table_name}"
engine = create_engine(database_uri)
def _fetch_rows():
with engine.connect() as conn:
cursor = conn.execute(text(f'SELECT * FROM "{table_name}"'))
return cursor.fetchall()
result = await asyncio.to_thread(_fetch_rows)
nodes = [TextNode(text=str(tuple(row))) for row in result]
index = VectorStoreIndex(nodes)
index.set_index_id("vector_index")
index.storage_context.persist(path)
return path
@env.task
async def index_all_tables(db_file: File) -> Dir:
"""Index all tables concurrently."""
table_index_dir = "table_indices"
Path(table_index_dir).mkdir(exist_ok=True)
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
tasks = [
index_table(t, table_index_dir, "sqlite:///local_db.sqlite")
for t in sql_database.get_usable_table_names()
]
await asyncio.gather(*tasks)
remote_dir = await Dir.from_local(table_index_dir)
return remote_dir
# {{/docs-fragment index_tables}}
@flyte.trace
async def get_table_schema_context(
table_schema_obj: SQLTableSchema,
database_uri: str,
) -> str:
"""Retrieve schema + optional description context for a single table."""
engine = create_engine(database_uri)
sql_database = SQLDatabase(engine)
table_info = sql_database.get_single_table_info(table_schema_obj.table_name)
if table_schema_obj.context_str:
table_info += f" The table description is: {table_schema_obj.context_str}"
return table_info
@flyte.trace
async def get_table_row_context(
table_schema_obj: SQLTableSchema,
local_vector_index_dir: str,
query: str,
) -> str:
"""Retrieve row-level context examples using vector search."""
storage_context = StorageContext.from_defaults(
persist_dir=str(f"{local_vector_index_dir}/{table_schema_obj.table_name}")
)
vector_index = load_index_from_storage(storage_context, index_id="vector_index")
vector_retriever = vector_index.as_retriever(similarity_top_k=2)
relevant_nodes = vector_retriever.retrieve(query)
if not relevant_nodes:
return ""
row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
for node in relevant_nodes:
row_context += str(node.get_content()) + "\n"
return row_context
async def process_table(
table_schema_obj: SQLTableSchema,
database_uri: str,
local_vector_index_dir: str,
query: str,
) -> str:
"""Combine schema + row context for one table."""
table_info = await get_table_schema_context(table_schema_obj, database_uri)
row_context = await get_table_row_context(
table_schema_obj, local_vector_index_dir, query
)
full_context = table_info
if row_context:
full_context += "\n" + row_context
print(f"Table Info: {full_context}")
return full_context
async def get_table_context_and_rows_str(
query: str,
database_uri: str,
table_schema_objs: list[SQLTableSchema],
vector_index_dir: Dir,
):
"""Get combined schema + row context for all tables."""
local_vector_index_dir = await vector_index_dir.download()
# run per-table work concurrently
context_strs = await asyncio.gather(
*[
process_table(t, database_uri, local_vector_index_dir, query)
for t in table_schema_objs
]
)
return "\n\n".join(context_strs)
# {{docs-fragment retrieve_tables}}
@env.task
async def retrieve_tables(
query: str,
table_infos: list[TableInfo | None],
db_file: File,
vector_index_dir: Dir,
) -> str:
"""Retrieve relevant tables and return schema context string."""
await db_file.download(local_path="local_db.sqlite")
engine = create_engine("sqlite:///local_db.sqlite")
sql_database = SQLDatabase(engine)
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
for t in table_infos
if t is not None
]
obj_index = ObjectIndex.from_objects(
table_schema_objs,
table_node_mapping,
VectorStoreIndex,
)
obj_retriever = obj_index.as_retriever(similarity_top_k=3)
retrieved_schemas = obj_retriever.retrieve(query)
return await get_table_context_and_rows_str(
query, "sqlite:///local_db.sqlite", retrieved_schemas, vector_index_dir
)
# {{/docs-fragment retrieve_tables}}
def parse_response_to_sql(chat_response: ChatResponse) -> str:
"""Extract SQL query from LLM response."""
response = chat_response.message.content
sql_query_start = response.find("SQLQuery:")
if sql_query_start != -1:
response = response[sql_query_start:]
if response.startswith("SQLQuery:"):
response = response[len("SQLQuery:") :]
sql_result_start = response.find("SQLResult:")
if sql_result_start != -1:
response = response[:sql_result_start]
return response.strip().strip("```").strip()
# {{docs-fragment sql_and_response}}
@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
"""Generate SQL query from natural language question and table context."""
llm = OpenAI(model=model)
fmt_messages = (
PromptTemplate(
prompt,
prompt_type=PromptType.TEXT_TO_SQL,
)
.partial_format(dialect="sqlite")
.format_messages(query_str=query, schema=table_context)
)
chat_response = await llm.achat(fmt_messages)
return parse_response_to_sql(chat_response)
@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
    """Run SQL query on database and synthesize final response."""
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)
    retrieved_rows = sql_retriever.retrieve(sql)
    response_synthesis_prompt = PromptTemplate(
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )
    llm = OpenAI(model=model)
    fmt_messages = response_synthesis_prompt.format_messages(
        sql_query=sql,
        context_str=str(retrieved_rows),
        query_str=query,
    )
    chat_response = await llm.achat(fmt_messages)
    return chat_response.message.content
# {{/docs-fragment sql_and_response}}
# {{docs-fragment text_to_sql}}
@env.task
async def text_to_sql(
    system_prompt: str = (
        "Given an input question, first create a syntactically correct {dialect} "
        "query to run, then look at the results of the query and return the answer. "
        "You can order the results by a relevant column to return the most "
        "interesting examples in the database.\n\n"
        "Never query for all the columns from a specific table, only ask for a "
        "few relevant columns given the question.\n\n"
        "Pay attention to use only the column names that you can see in the schema "
        "description. "
        "Be careful to not query for columns that do not exist. "
        "Pay attention to which column is in which table. "
        "Also, qualify column names with the table name when needed. "
        "You are required to use the following format, each taking one line:\n\n"
        "Question: Question here\n"
        "SQLQuery: SQL Query to run\n"
        "SQLResult: Result of the SQLQuery\n"
        "Answer: Final answer here\n\n"
        "Only use tables listed below.\n"
        "{schema}\n\n"
        "Question: {query_str}\n"
        "SQLQuery: "
    ),
    query: str = "What was the year that The Notorious BIG was signed to Bad Boy?",
    model: str = "gpt-4o-mini",
) -> str:
    db_file, table_infos = await data_ingestion()
    vector_index_dir = await index_all_tables(db_file)
    table_context = await retrieve_tables(query, table_infos, db_file, vector_index_dir)
    sql = await generate_sql(query, table_context, model, system_prompt)
    return await generate_response(query, sql, db_file, model)
# {{/docs-fragment text_to_sql}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(text_to_sql)
    print(run.url)
    run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py)
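The retriever selects the three tables most semantically similar to the question (`similarity_top_k=3`, matched against each table's name and summary), then attaches each table's schema, its summary, and the two rows most similar to the question. This context grounds the model's SQL generation in the database's actual structure and content.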
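At its core, table selection is top-k nearest-neighbor search over embeddings. The sketch below shows the idea with hand-written three-dimensional vectors standing in for real embeddings. The table names, the vectors, and the `top_k_tables` helper are hypothetical. LlamaIndex's `ObjectIndex` retriever uses an embedding model and a vector store instead of this explicit loop.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_tables(query_vec: list[float], table_vecs: dict[str, list[float]], k: int) -> list[str]:
    """Return the k table names whose embeddings are closest to the query."""
    ranked = sorted(
        table_vecs.items(),
        key=lambda item: cosine_similarity(query_vec, item[1]),
        reverse=True,
    )
    return [name for name, _ in ranked[:k]]


# Hypothetical toy embeddings; dimensions loosely mean "music", "sports", "geography".
tables = {
    "record_labels": [0.9, 0.1, 0.0],
    "album_releases": [0.8, 0.0, 0.2],
    "football_seasons": [0.0, 1.0, 0.1],
    "capital_cities": [0.1, 0.0, 0.9],
}
music_question = [1.0, 0.0, 0.0]  # stands in for an embedded question about a record label
print(top_k_tables(music_question, tables, k=2))  # ['record_labels', 'album_releases']
```

The pipeline applies the same ranking idea twice: once across tables (top three) and once across the rows of each selected table (top two).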
### SQL generation and response synthesis
Finally, we generate a SQL query from the retrieved context, run it against the database, and turn the result into a natural language answer.
```
@env.task
async def generate_sql(query: str, table_context: str, model: str, prompt: str) -> str:
    """Generate SQL query from natural language question and table context."""
    llm = OpenAI(model=model)
    fmt_messages = (
        PromptTemplate(
            prompt,
            prompt_type=PromptType.TEXT_TO_SQL,
        )
        .partial_format(dialect="sqlite")
        .format_messages(query_str=query, schema=table_context)
    )
    chat_response = await llm.achat(fmt_messages)
    return parse_response_to_sql(chat_response)
@env.task
async def generate_response(query: str, sql: str, db_file: File, model: str) -> str:
    """Run SQL query on database and synthesize final response."""
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)
    retrieved_rows = sql_retriever.retrieve(sql)
    response_synthesis_prompt = PromptTemplate(
        "Given an input question, synthesize a response from the query results.\n"
        "Query: {query_str}\n"
        "SQL: {sql_query}\n"
        "SQL Response: {context_str}\n"
        "Response: "
    )
    llm = OpenAI(model=model)
    fmt_messages = response_synthesis_prompt.format_messages(
        sql_query=sql,
        context_str=str(retrieved_rows),
        query_str=query,
    )
    chat_response = await llm.achat(fmt_messages)
    return chat_response.message.content
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/text_to_sql.py)
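The SQL generation prompt combines the retrieved schema and example rows with formatting rules that ask the model to answer in a `SQLQuery:` / `SQLResult:` / `Answer:` layout. `parse_response_to_sql` keeps only the text between `SQLQuery:` and `SQLResult:`. `generate_response` then executes that query with `SQLRetriever` and asks the LLM to turn the returned rows into a final answer.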
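To see what the parsing step does, here is the same extraction logic applied directly to a string rather than a LlamaIndex `ChatResponse`. The `extract_sql` name and the sample model reply are hypothetical. The reply only illustrates the prompt's line format and is not real model output.

```python
def extract_sql(response: str) -> str:
    """Same steps as parse_response_to_sql, but takes the raw response text."""
    start = response.find("SQLQuery:")
    if start != -1:
        response = response[start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    end = response.find("SQLResult:")
    if end != -1:
        response = response[:end]
    # Trim whitespace, then any backtick characters left at either end.
    return response.strip().strip("```").strip()


# Hypothetical reply that follows the prompt's line-per-field format.
raw = (
    "SQLQuery: SELECT year FROM artists WHERE artist = 'The Notorious B.I.G.';\n"
    "SQLResult: ...\n"
    "Answer: ..."
)
print(extract_sql(raw))  # SELECT year FROM artists WHERE artist = 'The Notorious B.I.G.';
```

Because `strip("```")` removes backtick characters rather than a whole fence, a reply wrapped in a fence with a language tag (three backticks followed by `sql`) would keep the `sql` tag at the start of the extracted query.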
At this point, we have an end-to-end Text-to-SQL pipeline: natural language questions go in, SQL queries run, and answers come back. To make this workflow production-ready, we leveraged several Flyte 2 capabilities. Caching ensures that repeated steps, like table ingestion or vector indexing, don't need to rerun unnecessarily, saving time and compute. Containerization provides consistent, reproducible execution across environments, making it easier to scale and deploy. Observability features let us track every step of the pipeline, monitor performance, and debug issues quickly.
The pipeline works end-to-end, but to see how well it performs across many questions, and to improve it gradually, we can start experimenting with prompt tuning.
Two things help make this process meaningful:
- **A clean evaluation dataset** - so we can measure accuracy against trusted ground truth.
- **A systematic evaluation loop** - so we can see whether prompt changes or other adjustments actually help.
With these in place, the next step is to build a "golden" QA dataset that will guide iterative prompt optimization.
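One simple metric for such a loop is execution accuracy: run the predicted SQL and check whether it returns the expected rows. The sketch below assumes each dataset row stores its expected result as `str(rows)` in a `target` column, which matches how the dataset builder in the next section fills `target`. The `execution_accuracy` helper, the toy table, and the predictions are hypothetical and not part of the tutorial's code.

```python
import sqlite3


def execution_accuracy(
    conn: sqlite3.Connection, examples: list[dict[str, str]], predicted_sql: list[str]
) -> float:
    """Fraction of examples whose predicted query returns exactly the target rows."""
    correct = 0
    for example, sql in zip(examples, predicted_sql):
        try:
            rows = conn.execute(sql).fetchall()
        except sqlite3.Error:
            continue  # a query that fails to run counts as wrong
        if str(rows) == example["target"]:
            correct += 1
    return correct / len(examples) if examples else 0.0


# Hypothetical toy database and evaluation rows.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE artists (name TEXT, label TEXT)")
conn.executemany("INSERT INTO artists VALUES (?, ?)", [("A", "X"), ("B", "Y")])
examples = [
    {"input": "Which artist is on label X?", "target": str([("A",)])},
    {"input": "How many artists are there?", "target": str([(2,)])},
]
predictions = [
    "SELECT name FROM artists WHERE label = 'X'",
    "SELECT COUNT(*) FROM artist",  # wrong table name, so this one fails
]
print(execution_accuracy(conn, examples, predictions))  # 0.5
```

Comparing `str(rows)` is strict: the same rows in a different order count as a miss, so a set-based comparison may be fairer for queries without `ORDER BY`.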
## Building the QA dataset
> [!NOTE]
> The WikiTableQuestions dataset already includes question-answer pairs, available in its [GitHub repository](https://github.com/ppasupat/WikiTableQuestions/tree/master/data). To use them for this workflow, you'll need to adapt the data into the required format, but the raw material is there for you to build on.
We generate a dataset of natural language questions paired with executable SQL queries. This dataset acts as the benchmark for prompt tuning and evaluation.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pandas>=2.0.0",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# "pydantic>=2.0.0",
# ]
# main = "build_eval_dataset"
# params = ""
# ///
import sqlite3
import flyte
import pandas as pd
from data_ingestion import data_ingestion
from flyte.io import File
from llama_index.core import PromptTemplate
from llama_index.llms.openai import OpenAI
from utils import env
from pydantic import BaseModel
class QAItem(BaseModel):
    question: str
    sql: str
class QAList(BaseModel):
    items: list[QAItem]
# {{docs-fragment get_and_split_schema}}
@env.task
async def get_and_split_schema(db_file: File, tables_per_chunk: int) -> list[str]:
    """
    Download the SQLite DB, extract schema info (columns + sample rows),
    then split it into chunks with up to `tables_per_chunk` tables each.
    """
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()
    tables = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()
    schema_blocks = []
    for table in tables:
        table_name = table[0]
        # columns
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = [col[1] for col in cursor.fetchall()]
        block = f"Table: {table_name}({', '.join(columns)})"
        # sample rows
        cursor.execute(f"SELECT * FROM {table_name} LIMIT 10;")
        rows = cursor.fetchall()
        if rows:
            block += "\nSample rows:\n"
            for row in rows:
                block += f"{row}\n"
        schema_blocks.append(block)
    conn.close()
    chunks = []
    current_chunk = []
    for block in schema_blocks:
        current_chunk.append(block)
        if len(current_chunk) >= tables_per_chunk:
            chunks.append("\n".join(current_chunk))
            current_chunk = []
    if current_chunk:
        chunks.append("\n".join(current_chunk))
    return chunks
# {{/docs-fragment get_and_split_schema}}
# {{docs-fragment generate_questions_and_sql}}
@flyte.trace
async def generate_questions_and_sql(
    schema: str, num_samples: int, batch_size: int
) -> QAList:
    llm = OpenAI(model="gpt-4.1")
    prompt_tmpl = PromptTemplate(
        """Prompt: You are helping build a Text-to-SQL dataset.
Here is the database schema:
{schema}
Generate {num} natural language questions a user might ask about this database.
For each question, also provide the correct SQL query.
Reasoning process (you must follow this internally):
- Given an input question, first create a syntactically correct {dialect} SQL query.
- Never use SELECT *; only include the relevant columns.
- Use only columns/tables from the schema. Qualify column names when ambiguous.
- You may order results by a meaningful column to make the query more useful.
- Be careful not to add unnecessary columns.
- Use filters, aggregations, joins, grouping, and subqueries when relevant.
Final Output:
Return only a JSON object with one field:
- "items": a list of {num} objects, each with:
- "question": the natural language question
- "sql": the corresponding SQL query
"""
    )
    all_items: list[QAItem] = []
    # batch generation: at most `batch_size` pairs per LLM call
    for start in range(0, num_samples, batch_size):
        current_num = min(batch_size, num_samples - start)
        # async variant avoids blocking the event loop; dialect fills the
        # {dialect} placeholder used in the prompt
        response = await llm.astructured_predict(
            QAList,
            prompt_tmpl,
            schema=schema,
            num=current_num,
            dialect="sqlite",
        )
        all_items.extend(response.items)
    # deduplicate
    seen = set()
    unique_items: list[QAItem] = []
    for item in all_items:
        key = (item.question.strip().lower(), item.sql.strip().lower())
        if key not in seen:
            seen.add(key)
            unique_items.append(item)
    return QAList(items=unique_items[:num_samples])
# {{/docs-fragment generate_questions_and_sql}}
@flyte.trace
async def llm_validate_batch(pairs: list[dict[str, str]]) -> list[str]:
    """Validate a batch of question/sql/result dicts using one LLM call."""
    batch_prompt = """You are validating the correctness of SQL query results against the question.
For each example, answer only "True" (correct) or "False" (incorrect).
Output one answer per line, in the same order as the examples.
---
"""
    for i, pair in enumerate(pairs, start=1):
        batch_prompt += f"""
Example {i}:
Question:
{pair['question']}
SQL:
{pair['sql']}
Result:
{pair['rows']}
---
"""
    llm = OpenAI(model="gpt-4.1")
    resp = await llm.acomplete(batch_prompt)
    # Expect exactly one True/False per example
    results = [
        line.strip()
        for line in resp.text.splitlines()
        if line.strip() in ("True", "False")
    ]
    return results
# {{docs-fragment validate_sql}}
@env.task
async def validate_sql(
    db_file: File, question_sql_pairs: QAList, batch_size: int
) -> list[dict[str, str]]:
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()
    qa_data = []
    batch = []
    for pair in question_sql_pairs.items:
        q, sql = pair.question, pair.sql
        try:
            cursor.execute(sql)
            rows = cursor.fetchall()
            batch.append({"question": q, "sql": sql, "rows": str(rows)})
            # process when batch is full
            if len(batch) == batch_size:
                results = await llm_validate_batch(batch)
                for pair, is_valid in zip(batch, results):
                    if is_valid == "True":
                        qa_data.append(
                            {
                                "input": pair["question"],
                                "sql": pair["sql"],
                                "target": pair["rows"],
                            }
                        )
                    else:
                        print(f"Filtered out incorrect result for: {pair['question']}")
                batch = []
        except Exception as e:
            print(f"Skipping invalid SQL: {sql} ({e})")
    # process leftover batch
    if batch:
        results = await llm_validate_batch(batch)
        for pair, is_valid in zip(batch, results):
            if is_valid == "True":
                qa_data.append(
                    {
                        "input": pair["question"],
                        "sql": pair["sql"],
                        "target": pair["rows"],
                    }
                )
            else:
                print(f"Filtered out incorrect result for: {pair['question']}")
    conn.close()
    return qa_data
# {{/docs-fragment validate_sql}}
@flyte.trace
async def save_to_csv(qa_data: list[dict]) -> File:
    df = pd.DataFrame(qa_data, columns=["input", "target", "sql"])
    csv_file = "qa_dataset.csv"
    df.to_csv(csv_file, index=False)
    return await File.from_local(csv_file)
# {{docs-fragment build_eval_dataset}}
@env.task
async def build_eval_dataset(
    num_samples: int = 300, batch_size: int = 30, tables_per_chunk: int = 3
) -> File:
    db_file, _ = await data_ingestion()
    schema_chunks = await get_and_split_schema(db_file, tables_per_chunk)
    per_chunk_samples = max(1, num_samples // len(schema_chunks))
    final_qa_data = []
    for chunk in schema_chunks:
        qa_list = await generate_questions_and_sql(
            schema=chunk,
            num_samples=per_chunk_samples,
            batch_size=batch_size,
        )
        qa_data = await validate_sql(db_file, qa_list, batch_size)
        final_qa_data.extend(qa_data)
    csv_file = await save_to_csv(final_qa_data)
    return csv_file
# {{/docs-fragment build_eval_dataset}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(build_eval_dataset)
    print(run.url)
    run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py)
The pipeline does the following:
- Schema extraction: pull the table names, column names, and up to 10 sample rows for every table in the database.
- Question-SQL generation: use an LLM to produce natural language questions with matching SQL queries.
- Validation: run each query against the database, drop queries that fail to execute, and use an LLM judge to remove results that don't answer the question.
- Final export: store the clean, validated pairs in CSV format for downstream use.
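The validation step combines a hard filter (the query must execute) with a soft one (an LLM judge must accept the result). Below is a sketch of that logic with the LLM call replaced by a hard-coded judge reply, so it runs without network access. The helper names and the toy table are hypothetical. The strict `True`/`False` line filtering and the `input`/`sql`/`target` output keys mirror `llm_validate_batch` and `validate_sql`.

```python
import sqlite3


def execute_candidates(conn: sqlite3.Connection, pairs: list[tuple[str, str]]) -> list[dict[str, str]]:
    """Run each (question, sql) pair and keep only those that execute."""
    executed = []
    for question, sql in pairs:
        try:
            rows = conn.execute(sql).fetchall()
        except sqlite3.Error as e:
            print(f"Skipping invalid SQL: {sql} ({e})")
            continue
        executed.append({"question": question, "sql": sql, "rows": str(rows)})
    return executed


def parse_verdicts(judge_text: str) -> list[str]:
    """Keep only lines that are exactly True or False after stripping."""
    return [line.strip() for line in judge_text.splitlines() if line.strip() in ("True", "False")]


def keep_valid(executed: list[dict[str, str]], verdicts: list[str]) -> list[dict[str, str]]:
    """Convert accepted examples to the input/sql/target dataset format."""
    return [
        {"input": ex["question"], "sql": ex["sql"], "target": ex["rows"]}
        for ex, verdict in zip(executed, verdicts)
        if verdict == "True"
    ]


# Hypothetical toy data.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE seasons (year INTEGER, wins INTEGER)")
conn.executemany("INSERT INTO seasons VALUES (?, ?)", [(2019, 10), (2020, 12)])
candidates = [
    ("Which season had the most wins?", "SELECT year FROM seasons ORDER BY wins DESC LIMIT 1"),
    ("How many wins in 2019?", "SELECT wins FROM seasons WHERE year = 2019"),
    ("What was the attendance?", "SELECT attendance FROM seasons"),  # no such column
]
executed = execute_candidates(conn, candidates)
judge_reply = "True\nFalse"  # stands in for the LLM judge's answer
dataset = keep_valid(executed, parse_verdicts(judge_reply))
print(dataset)
```

`zip` stops at the shorter list. If the judge's reply contains fewer clean `True`/`False` lines than there are examples, the trailing examples are silently dropped. This is one reason the validation prompt asks for exactly one answer per line.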
### Schema extraction and chunking
We split the schema into chunks of `tables_per_chunk` tables (three by default) and request roughly the same number of questions for each chunk. This spreads the questions across all tables instead of letting a few tables dominate the dataset.
```
@env.task
async def get_and_split_schema(db_file: File, tables_per_chunk: int) -> list[str]:
    """
    Download the SQLite DB, extract schema info (columns + sample rows),
    then split it into chunks with up to `tables_per_chunk` tables each.
    """
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()
    tables = cursor.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()
    schema_blocks = []
    for table in tables:
        table_name = table[0]
        # columns
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = [col[1] for col in cursor.fetchall()]
        block = f"Table: {table_name}({', '.join(columns)})"
        # sample rows
        cursor.execute(f"SELECT * FROM {table_name} LIMIT 10;")
        rows = cursor.fetchall()
        if rows:
            block += "\nSample rows:\n"
            for row in rows:
                block += f"{row}\n"
        schema_blocks.append(block)
    conn.close()
    chunks = []
    current_chunk = []
    for block in schema_blocks:
        current_chunk.append(block)
        if len(current_chunk) >= tables_per_chunk:
            chunks.append("\n".join(current_chunk))
            current_chunk = []
    if current_chunk:
        chunks.append("\n".join(current_chunk))
    return chunks
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py)
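The chunking and the per-chunk sample budget are easy to check in isolation. This sketch reproduces the grouping with list slicing, which yields the same chunks as the accumulate-and-flush loop above, and the `max(1, num_samples // len(schema_chunks))` arithmetic from `build_eval_dataset`. The helper names and the 20-table example are hypothetical.

```python
def split_into_chunks(blocks: list[str], tables_per_chunk: int) -> list[list[str]]:
    """Group schema blocks into chunks of at most tables_per_chunk tables."""
    return [blocks[i : i + tables_per_chunk] for i in range(0, len(blocks), tables_per_chunk)]


def samples_per_chunk(num_samples: int, num_chunks: int) -> int:
    """Floor division with a minimum of one, as in build_eval_dataset."""
    return max(1, num_samples // num_chunks)


blocks = [f"Table: t{i}" for i in range(20)]  # hypothetical database with 20 tables
chunks = split_into_chunks(blocks, tables_per_chunk=3)
print([len(c) for c in chunks])  # [3, 3, 3, 3, 3, 3, 2]

per_chunk = samples_per_chunk(300, len(chunks))
print(per_chunk, per_chunk * len(chunks))  # 42 294
```

Because of the floor division, `num_samples` acts as an upper bound rather than an exact target. With 7 chunks, at most 294 pairs are requested, before deduplication and validation remove any.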
### Question and SQL generation
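Using structured output (the `QAList` Pydantic model), we ask an LLM to generate realistic questions users might ask, each paired with a SQL query, in batches of up to `batch_size` pairs per call. Deduplication on the lowercased, whitespace-trimmed question and SQL then removes exact repeats.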
```
@flyte.trace
async def generate_questions_and_sql(
    schema: str, num_samples: int, batch_size: int
) -> QAList:
    llm = OpenAI(model="gpt-4.1")
    prompt_tmpl = PromptTemplate(
        """Prompt: You are helping build a Text-to-SQL dataset.
Here is the database schema:
{schema}
Generate {num} natural language questions a user might ask about this database.
For each question, also provide the correct SQL query.
Reasoning process (you must follow this internally):
- Given an input question, first create a syntactically correct {dialect} SQL query.
- Never use SELECT *; only include the relevant columns.
- Use only columns/tables from the schema. Qualify column names when ambiguous.
- You may order results by a meaningful column to make the query more useful.
- Be careful not to add unnecessary columns.
- Use filters, aggregations, joins, grouping, and subqueries when relevant.
Final Output:
Return only a JSON object with one field:
- "items": a list of {num} objects, each with:
- "question": the natural language question
- "sql": the corresponding SQL query
"""
    )
    all_items: list[QAItem] = []
    # batch generation: at most `batch_size` pairs per LLM call
    for start in range(0, num_samples, batch_size):
        current_num = min(batch_size, num_samples - start)
        # async variant avoids blocking the event loop; dialect fills the
        # {dialect} placeholder used in the prompt
        response = await llm.astructured_predict(
            QAList,
            prompt_tmpl,
            schema=schema,
            num=current_num,
            dialect="sqlite",
        )
        all_items.extend(response.items)
    # deduplicate
    seen = set()
    unique_items: list[QAItem] = []
    for item in all_items:
        key = (item.question.strip().lower(), item.sql.strip().lower())
        if key not in seen:
            seen.add(key)
            unique_items.append(item)
    return QAList(items=unique_items[:num_samples])
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py)
### Validation and quality control
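Each generated SQL query is executed against the database. A second LLM call then checks, in batches of `batch_size` examples, whether each result matches the intent of the natural-language question. Queries that fail to execute, or that the validator marks as incorrect, are dropped from the dataset.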
```
@flyte.trace
async def llm_validate_batch(pairs: list[dict[str, str]]) -> list[str]:
    """Validate a batch of question/sql/result dicts using one LLM call."""
    batch_prompt = """You are validating the correctness of SQL query results against the question.
For each example, answer only "True" (correct) or "False" (incorrect).
Output one answer per line, in the same order as the examples.
---
"""
    for i, pair in enumerate(pairs, start=1):
        batch_prompt += f"""
Example {i}:
Question:
{pair['question']}
SQL:
{pair['sql']}
Result:
{pair['rows']}
---
"""

    llm = OpenAI(model="gpt-4.1")
    resp = await llm.acomplete(batch_prompt)

    # Expect exactly one True/False per example
    results = [
        line.strip()
        for line in resp.text.splitlines()
        if line.strip() in ("True", "False")
    ]
    return results


@env.task
async def validate_sql(
    db_file: File, question_sql_pairs: QAList, batch_size: int
) -> list[dict[str, str]]:
    await db_file.download(local_path="local_db.sqlite")
    conn = sqlite3.connect("local_db.sqlite")
    cursor = conn.cursor()

    qa_data = []
    batch = []

    for pair in question_sql_pairs.items:
        q, sql = pair.question, pair.sql
        try:
            cursor.execute(sql)
            rows = cursor.fetchall()
        except Exception as e:
            # Only SQL execution errors are skipped here; LLM errors are not masked
            print(f"Skipping invalid SQL: {sql} ({e})")
            continue

        batch.append({"question": q, "sql": sql, "rows": str(rows)})

        # process when batch is full
        if len(batch) == batch_size:
            results = await llm_validate_batch(batch)
            for item, is_valid in zip(batch, results):
                if is_valid == "True":
                    qa_data.append(
                        {
                            "input": item["question"],
                            "sql": item["sql"],
                            "target": item["rows"],
                        }
                    )
                else:
                    print(f"Filtered out incorrect result for: {item['question']}")
            batch = []

    # process leftover batch
    if batch:
        results = await llm_validate_batch(batch)
        for item, is_valid in zip(batch, results):
            if is_valid == "True":
                qa_data.append(
                    {
                        "input": item["question"],
                        "sql": item["sql"],
                        "target": item["rows"],
                    }
                )
            else:
                print(f"Filtered out incorrect result for: {item['question']}")

    conn.close()
    return qa_data
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/create_qa_dataset.py)
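Because `llm_validate_batch` keeps only lines that read exactly "True" or "False", and `validate_sql` pairs them with the batch via `zip()`, a reply with fewer verdicts than examples silently drops the unanswered examples. The following is a minimal offline sketch of the same round trip. `build_batch_prompt`, `parse_verdicts`, and the canned reply are hypothetical stand-ins for the real LLM call, and padding missing verdicts with `False` is one possible policy, not the tutorial's.

```python
# Minimal sketch of one batched validation round trip, assuming the LLM answers
# with one "True"/"False" line per example. A canned string replaces the LLM
# call so the sketch runs offline.


def build_batch_prompt(pairs: list[dict[str, str]]) -> str:
    """Concatenate question/SQL/result examples into a single validation prompt."""
    parts = [
        'For each example, answer only "True" (correct) or "False" (incorrect).\n'
        "Output one answer per line, in the same order as the examples.\n---\n"
    ]
    for i, pair in enumerate(pairs, start=1):
        parts.append(
            f"Example {i}:\nQuestion:\n{pair['question']}\n"
            f"SQL:\n{pair['sql']}\nResult:\n{pair['rows']}\n---\n"
        )
    return "".join(parts)


def parse_verdicts(text: str, expected: int) -> list[bool]:
    """Keep only True/False lines and pad missing answers with False, so a short
    reply marks examples as rejected instead of silently dropping them."""
    verdicts = [
        line.strip() == "True"
        for line in text.splitlines()
        if line.strip() in ("True", "False")
    ]
    verdicts = verdicts[:expected]
    return verdicts + [False] * (expected - len(verdicts))


pairs = [
    {"question": "How many albums?", "sql": "SELECT COUNT(*) FROM albums;", "rows": "[(12,)]"},
    {"question": "List artist names", "sql": "SELECT name FROM artists;", "rows": "[('A',), ('B',)]"},
    {"question": "Latest release year", "sql": "SELECT MAX(release_year) FROM albums;", "rows": "[(1970,)]"},
]
prompt = build_batch_prompt(pairs)
fake_reply = "Analysis...\nTrue\nFalse"  # only two verdicts for three examples
keep = [p for p, ok in zip(pairs, parse_verdicts(fake_reply, len(pairs))) if ok]
print(len(keep))
```

Treating a missing verdict as a rejection keeps the ground-truth set conservative, which matters because every kept pair becomes a label in later evaluation.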
Even with automated checks, human review remains critical. Since this dataset serves as the ground truth, mislabeled pairs can distort evaluation. For production use, always invest in human-in-the-loop review.
Support for human-in-the-loop pipelines is coming soon in Flyte 2!
## Optimizing prompts
With the QA dataset in place, we can turn to prompt optimization. The idea: start from a baseline prompt, generate new variants, and measure whether accuracy improves.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
#    "pandas>=2.0.0",
#    "sqlalchemy>=2.0.0",
#    "llama-index-core>=0.11.0",
#    "llama-index-llms-openai>=0.2.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///

import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union

import flyte
import flyte.report
import pandas as pd
from data_ingestion import TableInfo
from flyte.io import Dir, File
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import SQLRetriever
from sqlalchemy import create_engine
from text_to_sql import data_ingestion, generate_sql, index_all_tables, retrieve_tables
from utils import env

CSS = """
"""


@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into val/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Val/Test split
    df_renamed = df.rename(columns={"input": "question", "target": "answer"})
    n = len(df_renamed)
    split = n // 2
    df_val = df_renamed.iloc[:split]
    df_test = df_renamed.iloc[split:]

    return df_val, df_test


@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""


@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]


@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
    await db_file.download(local_path="local_db.sqlite")

    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)

    retrieved_rows = sql_retriever.retrieve(sql)
    if retrieved_rows:
        # Get the structured result and stringify
        return str(retrieved_rows[0].node.metadata["result"])

    return ""


async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    db_file: File,
    table_infos: list[TableInfo | None],
    vector_index_dir: Dir,
) -> dict:
    # Generate response from target model
    table_context = await retrieve_tables(
        question, table_infos, db_file, vector_index_dir
    )
    sql = await generate_sql(
        question,
        table_context,
        target_model_config.model_name,
        target_model_config.prompt,
    )
    sql = sql.replace("sql\n", "")

    try:
        response = await generate_response(db_file, sql)
    except Exception as e:
        print(f"Failed to generate response for question {question}: {e}")
        response = None

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                query_str=question,
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "sql": sql,
        "is_correct": verdict_clean == "true",
    }


async def run_grouped_task(
    i,
    index,
    question,
    answer,
    sql,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
    db_file,
    table_infos,
    vector_index_dir,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
                db_file,
                table_infos,
                vector_index_dir,
            )

            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "Yes"
                else:
                    correct_html = "No"

                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100

                # Update chart
                await flyte.report.log.aio(
                    f"",
                    do_flush=True,
                )

                # Add row to table
                await flyte.report.log.aio(
                    f"""
""",
                    do_flush=True,
                )

    # ... evaluate_prompt() and prompt_optimizer() are not reproduced in this
    # listing; see the linked source file. prompt_optimizer() ends with:
    #     return best_result.prompt, best_result.accuracy


async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"

    await flyte.report.log.aio(
        f"""
{html.escape(prompt)}
{pct:.1f}%
""",
        do_flush=True,
    )


@env.task
async def auto_prompt_engineering(
    ground_truth_csv: File | str = "/root/ground_truth.csv",
    db_config: DatabaseConfig = DatabaseConfig(
        csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
        search_glob="WikiTableQuestions/csv/200-csv/*.csv",
        concurrency=5,
        model="gpt-4o-mini",
    ),
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""Given an input question, create a syntactically correct {dialect} query to run.
Schema:
{schema}
Question: {query_str}
SQL query to run:
""",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.
- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".
Question:
{query_str}
Ground Truth:
{answer}
Model Response:
{response}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
{prompt_scores_str}
Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.
artists(id, name)
albums(id, title, artist_id, release_year)
How many albums did The Beatles release?
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
    ),
    max_iterations: int = 5,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
        ground_truth_csv = await File.from_local(ground_truth_csv)

    df_val, df_test = await data_prep(ground_truth_csv)

    best_prompt, val_accuracy = await prompt_optimizer(
        df_val,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
        db_config,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

    return {
        "best_prompt": best_prompt,
        "validation_accuracy": val_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py)
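The core idea (score a baseline, propose a variant, score it, keep the best) can be sketched as a plain async loop. This is not the tutorial's `prompt_optimizer` task: `PromptResult`, `optimize_prompt`, `fake_evaluate`, and `fake_propose` are hypothetical, and the two stubs stand in for the validation-set evaluation and the optimizer LLM.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable


@dataclass
class PromptResult:
    prompt: str
    accuracy: float


async def optimize_prompt(
    baseline: str,
    propose: Callable[[list[PromptResult]], Awaitable[str]],
    evaluate: Callable[[str], Awaitable[float]],
    max_iterations: int = 5,
) -> PromptResult:
    """Score the baseline, then repeatedly request and score a new variant.

    max() returns the first of several equal maxima, so the baseline wins ties
    and the result is never worse than the starting prompt.
    """
    history = [PromptResult(baseline, await evaluate(baseline))]
    for _ in range(max_iterations):
        # The proposer sees every prompt tried so far, in ascending order of accuracy.
        ranked = sorted(history, key=lambda r: r.accuracy)
        candidate = await propose(ranked)
        history.append(PromptResult(candidate, await evaluate(candidate)))
    return max(history, key=lambda r: r.accuracy)


async def fake_evaluate(prompt: str) -> float:
    """Hypothetical scorer: longer prompts score higher, capped at 1.0."""
    return min(len(prompt) / 100, 1.0)


async def fake_propose(ranked: list[PromptResult]) -> str:
    """Hypothetical proposer: extend the best prompt seen so far."""
    return ranked[-1].prompt + " Use only columns from the schema."


best = asyncio.run(
    optimize_prompt("Write a {dialect} query.", fake_propose, fake_evaluate, max_iterations=3)
)
print(f"{best.accuracy:.2f}: {best.prompt}")
```

Passing the evaluator and proposer in as callables keeps the loop independent of any particular model or dataset, which makes it easy to test with stubs like these before wiring in real LLM calls.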
### Evaluation pipeline
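We evaluate each prompt variant against the golden dataset, which `data_prep` shuffles and splits in half: the optimizer scores candidate prompts on the validation half, and the baseline and best prompts are then compared on the held-out test half. Accuracy is logged to the run report in real time as each row is graded.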
```
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Load Q&A data from a public Google Sheet CSV export URL and split into val/test DataFrames.
    The sheet should have columns: 'input' and 'target'.
    """
    df = pd.read_csv(
        await csv_file.download() if isinstance(csv_file, File) else csv_file
    )

    if "input" not in df.columns or "target" not in df.columns:
        raise ValueError("Sheet must contain 'input' and 'target' columns.")

    # Shuffle rows
    df = df.sample(frac=1, random_state=1234).reset_index(drop=True)

    # Val/Test split
    df_renamed = df.rename(columns={"input": "question", "target": "answer"})
    n = len(df_renamed)
    split = n // 2
    df_val = df_renamed.iloc[:split]
    df_test = df_renamed.iloc[split:]

    return df_val, df_test


@dataclass
class ModelConfig:
    model_name: str
    hosted_model_uri: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = 1000
    timeout: int = 600
    prompt: str = ""


@flyte.trace
async def call_model(
    model_config: ModelConfig,
    messages: list[dict[str, str]],
) -> str:
    from litellm import acompletion

    response = await acompletion(
        model=model_config.model_name,
        api_base=model_config.hosted_model_uri,
        messages=messages,
        temperature=model_config.temperature,
        timeout=model_config.timeout,
        max_tokens=model_config.max_tokens,
    )
    return response.choices[0].message["content"]


@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
    await db_file.download(local_path="local_db.sqlite")

    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)

    retrieved_rows = sql_retriever.retrieve(sql)
    if retrieved_rows:
        # Get the structured result and stringify
        return str(retrieved_rows[0].node.metadata["result"])

    return ""


async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    db_file: File,
    table_infos: list[TableInfo | None],
    vector_index_dir: Dir,
) -> dict:
    # Generate response from target model
    table_context = await retrieve_tables(
        question, table_infos, db_file, vector_index_dir
    )
    sql = await generate_sql(
        question,
        table_context,
        target_model_config.model_name,
        target_model_config.prompt,
    )
    sql = sql.replace("sql\n", "")

    try:
        response = await generate_response(db_file, sql)
    except Exception as e:
        print(f"Failed to generate response for question {question}: {e}")
        response = None

    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                query_str=question,
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)

    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"

    return {
        "index": index,
        "model_response": response,
        "sql": sql,
        "is_correct": verdict_clean == "true",
    }
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py)
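`run_grouped_task` combines two asyncio primitives: a semaphore that bounds how many rows are graded at once, and a lock that protects the shared counters behind the live accuracy chart. Here is a minimal sketch of that pattern that runs without Flyte or an LLM. `split_val_test` mirrors the seeded shuffle and 50/50 split in `data_prep` using plain lists instead of pandas, and `fake_judge` is a hypothetical stand-in for SQL generation plus the reviewer model.

```python
import asyncio
import random
from typing import Awaitable, Callable

Row = tuple[str, str]  # (question, ground-truth answer)


def split_val_test(rows: list[Row], seed: int = 1234) -> tuple[list[Row], list[Row]]:
    """Shuffle with a fixed seed, then use the first half for validation and the rest for test."""
    shuffled = rows[:]
    random.Random(seed).shuffle(shuffled)
    half = len(shuffled) // 2
    return shuffled[:half], shuffled[half:]


async def evaluate(
    rows: list[Row],
    judge: Callable[[str, str], Awaitable[bool]],
    concurrency: int = 10,
) -> float:
    """Grade every row with at most `concurrency` judge calls in flight,
    printing the running accuracy as each row finishes."""
    semaphore = asyncio.Semaphore(concurrency)
    lock = asyncio.Lock()
    counter = {"processed": 0, "correct": 0}

    async def grade(question: str, answer: str) -> bool:
        async with semaphore:  # cap concurrent model calls
            ok = await judge(question, answer)
        async with lock:  # counters are shared by all rows
            counter["processed"] += 1
            counter["correct"] += int(ok)
            running = counter["correct"] / counter["processed"]
            print(f"{counter['processed']}/{len(rows)} graded, running accuracy {running:.0%}")
        return ok

    results = await asyncio.gather(*(grade(q, a) for q, a in rows))
    return sum(results) / len(results) if results else 0.0


async def fake_judge(question: str, answer: str) -> bool:
    """Hypothetical stand-in for SQL generation plus the reviewer model."""
    await asyncio.sleep(0)
    return answer.isdigit()


data = [(f"q{i}", str(i) if i % 4 else "n/a") for i in range(8)]
val, test = split_val_test(data)
print(len(val), len(test))
print(asyncio.run(evaluate(data, fake_judge, concurrency=3)))
```

Fixing the shuffle seed, as `data_prep` does with `random_state=1234`, keeps the validation and test halves identical across runs, so accuracy differences between prompts are not confounded by a different split.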
Here's how prompt accuracy evolves over time, as shown in the UI report:

### Iterative optimization
An optimizer LLM proposes new prompts by analyzing the prompts tried so far together with their accuracy scores, reusing patterns from high-scoring prompts and avoiding those from low-scoring ones. Each candidate runs through the evaluation loop, and we keep the best performer.
```
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "pandas>=2.0.0",
# "sqlalchemy>=2.0.0",
# "llama-index-core>=0.11.0",
# "llama-index-llms-openai>=0.2.0",
# ]
# main = "auto_prompt_engineering"
# params = ""
# ///
import asyncio
import html
import os
import re
from dataclasses import dataclass
from typing import Optional, Union
import flyte
import flyte.report
import pandas as pd
from data_ingestion import TableInfo
from flyte.io import Dir, File
from llama_index.core import SQLDatabase
from llama_index.core.retrievers import SQLRetriever
from sqlalchemy import create_engine
from text_to_sql import data_ingestion, generate_sql, index_all_tables, retrieve_tables
from utils import env
CSS = """
"""
@env.task
async def data_prep(csv_file: File | str) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Load Q&A data from a public Google Sheet CSV export URL and split into val/test DataFrames.
The sheet should have columns: 'input' and 'target'.
"""
df = pd.read_csv(
await csv_file.download() if isinstance(csv_file, File) else csv_file
)
if "input" not in df.columns or "target" not in df.columns:
raise ValueError("Sheet must contain 'input' and 'target' columns.")
# Shuffle rows
df = df.sample(frac=1, random_state=1234).reset_index(drop=True)
# Val/Test split
df_renamed = df.rename(columns={"input": "question", "target": "answer"})
n = len(df_renamed)
split = n // 2
df_val = df_renamed.iloc[:split]
df_test = df_renamed.iloc[split:]
return df_val, df_test
@dataclass
class ModelConfig:
model_name: str
hosted_model_uri: Optional[str] = None
temperature: float = 0.0
max_tokens: Optional[int] = 1000
timeout: int = 600
prompt: str = ""
@flyte.trace
async def call_model(
model_config: ModelConfig,
messages: list[dict[str, str]],
) -> str:
from litellm import acompletion
response = await acompletion(
model=model_config.model_name,
api_base=model_config.hosted_model_uri,
messages=messages,
temperature=model_config.temperature,
timeout=model_config.timeout,
max_tokens=model_config.max_tokens,
)
return response.choices[0].message["content"]
@flyte.trace
async def generate_response(db_file: File, sql: str) -> str:
    await db_file.download(local_path="local_db.sqlite")
    engine = create_engine("sqlite:///local_db.sqlite")
    sql_database = SQLDatabase(engine)
    sql_retriever = SQLRetriever(sql_database)
    retrieved_rows = sql_retriever.retrieve(sql)
    if retrieved_rows:
        # Get the structured result and stringify
        return str(retrieved_rows[0].node.metadata["result"])
    return ""


async def generate_and_review(
    index: int,
    question: str,
    answer: str,
    target_model_config: ModelConfig,
    review_model_config: ModelConfig,
    db_file: File,
    table_infos: list[TableInfo | None],
    vector_index_dir: Dir,
) -> dict:
    # Generate response from target model
    table_context = await retrieve_tables(
        question, table_infos, db_file, vector_index_dir
    )
    sql = await generate_sql(
        question,
        table_context,
        target_model_config.model_name,
        target_model_config.prompt,
    )
    sql = sql.replace("sql\n", "")
    try:
        response = await generate_response(db_file, sql)
    except Exception as e:
        print(f"Failed to generate response for question {question}: {e}")
        response = None
    # Format review prompt with response + answer
    review_messages = [
        {
            "role": "system",
            "content": review_model_config.prompt.format(
                query_str=question,
                response=response,
                answer=answer,
            ),
        }
    ]
    verdict = await call_model(review_model_config, review_messages)
    # Normalize verdict
    verdict_clean = verdict.strip().lower()
    if verdict_clean not in {"true", "false"}:
        verdict_clean = "not sure"
    return {
        "index": index,
        "model_response": response,
        "sql": sql,
        "is_correct": verdict_clean == "true",
    }


async def run_grouped_task(
    i,
    index,
    question,
    answer,
    sql,
    semaphore,
    target_model_config,
    review_model_config,
    counter,
    counter_lock,
    db_file,
    table_infos,
    vector_index_dir,
):
    async with semaphore:
        with flyte.group(name=f"row-{i}"):
            result = await generate_and_review(
                index,
                question,
                answer,
                target_model_config,
                review_model_config,
                db_file,
                table_infos,
                vector_index_dir,
            )
            async with counter_lock:
                # Update counters
                counter["processed"] += 1
                if result["is_correct"]:
                    counter["correct"] += 1
                    correct_html = "✅ Yes"
                else:
                    correct_html = "❌ No"
                # Calculate accuracy
                accuracy_pct = (counter["correct"] / counter["processed"]) * 100
                # Update chart
                await flyte.report.log.aio(
                    f"",
                    do_flush=True,
                )
                # Add row to table
                await flyte.report.log.aio(
                    f"""
                    """,
                    do_flush=True,
                )
return best_result.prompt, best_result.accuracy
# {{/docs-fragment prompt_optimizer}}
async def _log_prompt_row(prompt: str, accuracy: float):
    """Helper to log a single prompt/accuracy row to Flyte report."""
    pct = accuracy * 100
    if pct > 80:
        color = "linear-gradient(90deg, #4CAF50, #81C784)"
    elif pct > 60:
        color = "linear-gradient(90deg, #FFC107, #FFD54F)"
    else:
        color = "linear-gradient(90deg, #F44336, #E57373)"
    await flyte.report.log.aio(
        f"""
{html.escape(prompt)}
{pct:.1f}%
""",
        do_flush=True,
    )


# {{docs-fragment auto_prompt_engineering}}
@env.task
async def auto_prompt_engineering(
    ground_truth_csv: File | str = "/root/ground_truth.csv",
    db_config: DatabaseConfig = DatabaseConfig(
        csv_zip_path="https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip",
        search_glob="WikiTableQuestions/csv/200-csv/*.csv",
        concurrency=5,
        model="gpt-4o-mini",
    ),
    target_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1-mini",
        hosted_model_uri=None,
        prompt="""Given an input question, create a syntactically correct {dialect} query to run.
Schema:
{schema}
Question: {query_str}
SQL query to run:
""",
        max_tokens=10000,
    ),
    review_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        prompt="""Your job is to determine whether the model's response is correct compared to the ground truth taking into account the context of the question.
Both answers were generated by running SQL queries on the same database.
- If the model's response contains all of the ground truth values, and any additional information is harmless (e.g., extra columns or metadata), output "True".
- If it adds incorrect or unrelated rows, or omits required values, output "False".
Question:
{query_str}
Ground Truth:
{answer}
Model Response:
{response}
""",
    ),
    optimizer_model_config: ModelConfig = ModelConfig(
        model_name="gpt-4.1",
        hosted_model_uri=None,
        temperature=0.7,
        max_tokens=None,
        prompt="""
I have some prompts along with their corresponding accuracies.
The prompts are arranged in ascending order based on their accuracy, where higher accuracy indicates better quality.
{prompt_scores_str}
Each prompt was used to translate a natural-language question into a SQL query against a provided database schema.
artists(id, name)
albums(id, title, artist_id, release_year)
How many albums did The Beatles release?
SELECT COUNT(*) FROM albums a JOIN artists r ON a.artist_id = r.id WHERE r.name = 'The Beatles';
Write a new prompt that will achieve an accuracy as high as possible and that is different from the old ones.
- It is very important that the new prompt is distinct from ALL the old ones!
- Ensure that you analyse the prompts with a high accuracy and reuse the patterns that worked in the past.
- Ensure that you analyse the prompts with a low accuracy and avoid the patterns that didn't work in the past.
- Think out loud before creating the prompt. Describe what has worked in the past and what hasn't. Only then create the new prompt.
- Use all available information like prompt length, formal/informal use of language, etc. for your analysis.
- Be creative, try out different ways of prompting the model. You may even come up with hypothetical scenarios that might improve the accuracy.
- You are generating a system prompt. Always use three placeholders for each prompt: dialect, schema, query_str.
- Write your new prompt in double square brackets. Use only plain text for the prompt text and do not add any markdown (i.e. no hashtags, backticks, quotes, etc).
""",
    ),
    max_iterations: int = 5,
    concurrency: int = 10,
) -> dict[str, Union[str, float]]:
    if isinstance(ground_truth_csv, str) and os.path.isfile(ground_truth_csv):
        ground_truth_csv = await File.from_local(ground_truth_csv)

    df_val, df_test = await data_prep(ground_truth_csv)

    best_prompt, val_accuracy = await prompt_optimizer(
        df_val,
        target_model_config,
        review_model_config,
        optimizer_model_config,
        max_iterations,
        concurrency,
        db_config,
    )

    with flyte.group(name="test_data_evaluation"):
        baseline_test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

        target_model_config.prompt = best_prompt
        test_accuracy = await evaluate_prompt(
            df_test,
            target_model_config,
            review_model_config,
            concurrency,
            db_config,
        )

    return {
        "best_prompt": best_prompt,
        "validation_accuracy": val_accuracy,
        "baseline_test_accuracy": baseline_test_accuracy,
        "test_accuracy": test_accuracy,
    }
# {{/docs-fragment auto_prompt_engineering}}


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(auto_prompt_engineering)
    print(run.url)
    run.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/text_to_sql/optimizer.py)
On paper, this creates a continuous improvement cycle: baseline → new variants → measured gains.
## Run it
To create the QA dataset:
```
python create_qa_dataset.py
```
To run the prompt optimization loop:
```
python optimizer.py
```
## What we observed
Prompt optimization didn't consistently lift SQL accuracy in this workflow. Accuracy plateaued near the baseline. But the process surfaced valuable lessons about what matters when building LLM-powered systems on real infrastructure.
- **Schema clarity matters**: CSV ingestion produced tables with overlapping names, creating ambiguity. This showed how schema design and metadata hygiene directly affect downstream evaluation.
- **Ground truth needs trust**: Because the dataset came from LLM outputs, noise remained even after filtering. Human review proved essential. Golden datasets need deliberate curation, not just automation.
- **Optimization needs context**: The optimizer couldn't "see" which examples failed, limiting its ability to improve. Feeding failures directly risks overfitting. A structured way to capture and reuse evaluation signals is the right long-term path.
Sometimes prompt tweaks alone can lift accuracy, but other times the real bottleneck lives in the data, the schema, or the evaluation loop. The lesson isn't "prompt optimization doesn't work", but that its impact depends on the system around it. Accuracy improves most reliably when prompts evolve alongside clean data, trusted evaluation, and observable feedback loops.
## The bigger lesson
Evaluation and optimization aren't one-off experiments; they're continuous processes. What makes them sustainable isn't a clever prompt; it's the platform around it.
Systems succeed when they:
- **Observe** failures with clarity: track exactly what failed and why.
- **Remain durable** across iterations: run pipelines that are stable, reproducible, and comparable over time.
That's where Flyte 2 comes in. Prompt optimization is one lever, but it becomes powerful only when combined with:
- Clean, human-validated evaluation datasets.
- Systematic reporting and feedback loops.
**The real takeaway: improving LLM pipelines isn't about chasing the perfect prompt. It's about designing workflows with observability and durability at the core, so that every experiment compounds into long-term progress.**
=== PAGE: https://www.union.ai/docs/v2/byoc/integrations ===
# Integrations
Flyte is designed to be highly extensible and can be customized
in multiple ways.
## Flyte Plugins
Flyte plugins extend the functionality of the `flyte` SDK.
| Plugin | Description |
| ------ | ----------- |
| **Flyte plugins > Ray** | Run Ray jobs on your Flyte cluster |
| **Flyte plugins > Spark** | Run Spark jobs on your Flyte cluster |
| **Flyte plugins > OpenAI** | Integrate with OpenAI SDKs in your Flyte workflows |
| **Flyte plugins > Dask** | Run Dask jobs on your Flyte cluster |
| **Flyte plugins > Pytorch** | Run distributed PyTorch jobs on your Flyte cluster |
## Subpages
- **Connectors**
- **Flyte plugins**
=== PAGE: https://www.union.ai/docs/v2/byoc/integrations/connectors ===
# Connectors
Connectors are stateless, long-running services that receive execution requests via gRPC and then submit work to external (or internal) systems. Each connector runs as its own Kubernetes deployment and is triggered when a Flyte task of the matching type is executed. For example, when a `BigQueryTask` is launched, the BigQuery connector receives the request and creates a BigQuery job.
> [!NOTE]
> The first connector for Flyte 2, the BigQuery connector (and the matching `BigQueryTask`), is in development and will be available soon.
Although connectors normally run inside the control plane, you can also run them locally (as long as the required secrets and credentials are present), because connectors are just Python services that can be spawned in-process.
Connectors are designed to scale horizontally and reduce load on the core Flyte backend because they execute *outside* the core system. This decoupling makes connectors efficient, resilient, and easy to iterate on. You can even test them locally without modifying backend configuration, which reduces friction during development.
## Creating a new connector
If none of the existing connectors meet your needs, you can build your own.
> [!NOTE]
> Connectors communicate via Protobuf, so in theory they can be implemented in any language.
> Today, only **Python** connectors are supported.
### Async connector interface
To implement a new async connector, extend `AsyncConnector` and implement the following methods, all of which **must be idempotent**:
| Method | Purpose |
|----------|-------------------------------------------------------------|
| `create` | Launch the external job (via REST, gRPC, SDK, or other API) |
| `get` | Fetch current job state (return job status or output) |
| `delete` | Delete / cancel the external job |
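Idempotency matters because the same call can be retried. A minimal sketch of what that means in practice, assuming the external service lets the caller choose (or deduplicate by) a job key: `create` derives a deterministic key from the execution and its inputs, so a retry finds the existing job instead of launching a second one, and `delete` treats an already-removed job as success. `FakeTrainingService`, `idempotency_key`, `create`, and `delete` here are hypothetical stand-ins, not Flyte SDK APIs.

```python
import hashlib
import json


class FakeTrainingService:
    """Hypothetical in-memory stand-in for an external job API."""

    def __init__(self) -> None:
        self.jobs: dict[str, str] = {}  # job_id -> status
        self.launches = 0

    def submit(self, job_id: str) -> None:
        # Only launch work for a key we have not seen before.
        if job_id not in self.jobs:
            self.jobs[job_id] = "running"
            self.launches += 1

    def cancel(self, job_id: str) -> None:
        # Cancelling an unknown or already-removed job is not an error.
        self.jobs.pop(job_id, None)


def idempotency_key(execution_id: str, inputs: dict) -> str:
    """Derive a stable job id from the execution id and its inputs."""
    payload = json.dumps({"execution": execution_id, "inputs": inputs}, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()[:16]


def create(service: FakeTrainingService, execution_id: str, inputs: dict) -> str:
    job_id = idempotency_key(execution_id, inputs)
    service.submit(job_id)  # a retry with the same key is a no-op
    return job_id


def delete(service: FakeTrainingService, job_id: str) -> None:
    service.cancel(job_id)  # safe to call more than once
```

`get` is naturally idempotent as long as it only reads state; the care goes into `create` and `delete`, which change it.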
To test the connector locally, the connector task should inherit from
[AsyncConnectorExecutorMixin](https://github.com/flyteorg/flyte-sdk/blob/1d49299294cd5e15385fe8c48089b3454b7a4cd1/src/flyte/connectors/_connector.py#L206).
This mixin simulates how the Flyte system executes asynchronous connector tasks, making it easier to validate your connector implementation before deploying it.
```python
from dataclasses import dataclass
from flyte.connectors import AsyncConnector, Resource, ResourceMeta
from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
from flyteidl2.core.tasks_pb2 import TaskTemplate
from google.protobuf import json_format
import typing
import httpx


@dataclass
class ModelTrainJobMeta(ResourceMeta):
    job_id: str
    endpoint: str


class ModelTrainingConnector(AsyncConnector):
    """
    Example connector that launches an ML model training job on an external training service.

    POST → launch training job
    GET → poll training progress
    DELETE → cancel training job
    """

    name = "Model Training Connector"
    task_type_name = "external_model_training"
    metadata_type = ModelTrainJobMeta

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: typing.Optional[typing.Dict[str, typing.Any]],
        **kwargs,
    ) -> ModelTrainJobMeta:
        """
        Submit training job via POST.
        Response returns job_id we later use in get().
        """
        custom = json_format.MessageToDict(task_template.custom) if task_template.custom else None
        async with httpx.AsyncClient() as client:
            r = await client.post(
                custom["endpoint"],
                json={"dataset_uri": inputs["dataset_uri"], "epochs": inputs["epochs"]},
            )
            r.raise_for_status()
            return ModelTrainJobMeta(job_id=r.json()["job_id"], endpoint=custom["endpoint"])

    async def get(self, resource_meta: ModelTrainJobMeta, **kwargs) -> Resource:
        """
        Poll external API until training job finishes.
        Must be safe to call repeatedly.
        """
        async with httpx.AsyncClient() as client:
            r = await client.get(f"{resource_meta.endpoint}/{resource_meta.job_id}")
            if r.status_code != 200:
                return Resource(phase=TaskExecution.RUNNING)
            data = r.json()
            if data["status"] == "finished":
                return Resource(
                    phase=TaskExecution.SUCCEEDED,
                    log_links=[TaskLog(name="training-dashboard", uri=f"https://example-mltrain.com/train/{resource_meta.job_id}")],
                    outputs={"results": data["results"]},
                )
            return Resource(phase=TaskExecution.RUNNING)

    async def delete(self, resource_meta: ModelTrainJobMeta, **kwargs):
        """
        Optionally call DELETE on external API.
        Safe even if job already completed.
        """
        async with httpx.AsyncClient() as client:
            await client.delete(f"{resource_meta.endpoint}/{resource_meta.job_id}")
```
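To see how `create`, `get`, and `delete` fit together, here is a rough, framework-free sketch of the polling loop an executor might drive: create the job once, poll until a terminal phase, and cancel the external job if the caller stops waiting. It is not the implementation of `AsyncConnectorExecutorMixin`; the `Phase` enum, `FakeConnector`, `run_connector`, the poll interval, and the timeout are all hypothetical.

```python
import asyncio
from enum import Enum


class Phase(Enum):
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"


class FakeConnector:
    """Hypothetical connector whose job finishes after a fixed number of polls."""

    def __init__(self, polls_until_done: int = 2) -> None:
        self.polls_left = polls_until_done
        self.deleted = False

    async def create(self, inputs: dict) -> str:
        return "job-123"

    async def get(self, job_id: str) -> tuple[Phase, dict]:
        if self.polls_left > 0:
            self.polls_left -= 1
            return Phase.RUNNING, {}
        return Phase.SUCCEEDED, {"results": f"{job_id}/model.pt"}

    async def delete(self, job_id: str) -> None:
        self.deleted = True


async def _poll(connector, job_id: str, poll_interval: float) -> dict:
    # get() may be called many times, which is why it must be side-effect free.
    while True:
        phase, outputs = await connector.get(job_id)
        if phase is Phase.SUCCEEDED:
            return outputs
        if phase is Phase.FAILED:
            raise RuntimeError(f"job {job_id} failed")
        await asyncio.sleep(poll_interval)


async def run_connector(connector, inputs: dict, poll_interval: float = 0.01, timeout: float = 5.0) -> dict:
    job_id = await connector.create(inputs)
    try:
        return await asyncio.wait_for(_poll(connector, job_id, poll_interval), timeout)
    except (asyncio.TimeoutError, asyncio.CancelledError):
        # Stop the external job if we give up waiting; delete must tolerate finished jobs.
        await connector.delete(job_id)
        raise
```

For example, `asyncio.run(run_connector(FakeConnector(), {"epochs": 3}))` returns `{"results": "job-123/model.pt"}` after two `RUNNING` polls.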
To actually use this connector, you must also define a task whose `task_type` matches the connector.
```python
import flyte.io
from typing import Any, Dict, Optional
from flyte.extend import TaskTemplate
from flyte.connectors import AsyncConnectorExecutorMixin
from flyte.models import NativeInterface, SerializationContext


class ModelTrainTask(AsyncConnectorExecutorMixin, TaskTemplate):
    _TASK_TYPE = "external_model_training"

    def __init__(
        self,
        name: str,
        endpoint: str,
        **kwargs,
    ):
        super().__init__(
            name=name,
            interface=NativeInterface(
                inputs={"epochs": int, "dataset_uri": str},
                outputs={"results": flyte.io.File},
            ),
            task_type=self._TASK_TYPE,
            **kwargs,
        )
        self.endpoint = endpoint

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        return {"endpoint": self.endpoint}
```
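The link between the task and the connector is a plain string match: `ModelTrainTask` declares `task_type="external_model_training"` and `ModelTrainingConnector` declares `task_type_name = "external_model_training"`. The toy registry below only illustrates that routing idea; `ConnectorRegistry` and `TrainingConnectorStub` are hypothetical and say nothing about how Flyte stores connectors internally.

```python
from typing import Any


class ConnectorRegistry:
    """Hypothetical routing table from a task type string to a connector."""

    def __init__(self) -> None:
        self._by_type: dict[str, Any] = {}

    def register(self, connector: Any) -> None:
        task_type = connector.task_type_name
        if task_type in self._by_type:
            raise ValueError(f"duplicate connector for task type {task_type!r}")
        self._by_type[task_type] = connector

    def resolve(self, task_type: str) -> Any:
        """Return the connector for a task's type, failing loudly on a mismatch."""
        try:
            return self._by_type[task_type]
        except KeyError:
            raise LookupError(f"no connector registered for task type {task_type!r}") from None


class TrainingConnectorStub:
    # Same string as ModelTrainingConnector.task_type_name and ModelTrainTask._TASK_TYPE.
    task_type_name = "external_model_training"
```

A typo in either string means no connector matches the task, which is why keeping the value in a single constant (as `_TASK_TYPE` does on the task side) is a sensible habit.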
Here is an example of how to use the `ModelTrainTask`:
```python
import flyte

# ModelTrainTask is the class defined in the previous snippet (same module).
env = flyte.TaskEnvironment(name="hello_world", resources=flyte.Resources(memory="250Mi"))

model_train_task = ModelTrainTask(
    name="model_train",
    endpoint="https://example-mltrain.com",
)


@env.task
def data_prep() -> str:
    return "gs://my-bucket/dataset.csv"


@env.task
def train_model(epochs: int) -> flyte.io.File:
    dataset_uri = data_prep()
    return model_train_task(epochs=epochs, dataset_uri=dataset_uri)
```
## Building a connector Docker image
When you're ready to deploy your connector to your cluster, build a Docker image for it.
The following script builds the SDK default connector image, optionally overriding the container registry and image name:
```python
import asyncio
from flyte import Image
from flyte.extend import ImageBuildEngine


async def build_flyte_connector_image(
    registry: str, name: str, builder: str = "local"
):
    """
    Build the SDK default connector image, optionally overriding
    the container registry and image name.

    Args:
        registry: e.g. "ghcr.io/my-org" or "123456789012.dkr.ecr.us-west-2.amazonaws.com".
        name: e.g. "my-connector".
        builder: e.g. "local" or "remote".
    """
    default_image = Image.from_debian_base(registry=registry, name=name).with_pip_packages(
        "flyteplugins-connectors[bigquery]", pre=True
    )
    await ImageBuildEngine.build(default_image, builder=builder)


if __name__ == "__main__":
    print("Building connector image...")
    asyncio.run(build_flyte_connector_image(registry="", name="flyte-connectors", builder="local"))
```
## Enabling a connector in your Union.ai deployment
To enable a connector in your Union.ai deployment, contact the Union.ai team.
=== PAGE: https://www.union.ai/docs/v2/byoc/integrations/flyte-plugins ===
# Flyte plugins
Flyte is designed to be extensible, allowing you to integrate new tools and frameworks into your workflows. By installing and configuring plugins, you can tailor Flyte to your data and compute ecosystem, whether you need to run large-scale distributed training, process data with a specific engine, or interact with external APIs.
Common reasons to extend Flyte include:
- **Specialized compute:** Use plugins like Spark or Ray to create distributed compute clusters.
- **AI integration:** Connect Flyte with SDKs like OpenAI's `openai-agents` to run agentic LLM applications.
- **Custom infrastructure:** Add plugins to interface with your organization's storage, databases, or proprietary systems.
For example, you can install the PyTorch plugin to run distributed PyTorch jobs natively on a Kubernetes cluster.
| Plugin | Description |
| ------ | ----------- |
| **Flyte plugins > Ray** | Run Ray jobs on your Flyte cluster |
| **Flyte plugins > Spark** | Run Spark jobs on your Flyte cluster |
| **Flyte plugins > OpenAI** | Integrate with OpenAI SDKs in your Flyte workflows |
| **Flyte plugins > Dask** | Run Dask jobs on your Flyte cluster |
| **Flyte plugins > Pytorch** | Run distributed PyTorch jobs on your Flyte cluster |
## Subpages
- **Flyte plugins > Dask**
- **Flyte plugins > OpenAI**
- **Flyte plugins > Pytorch**
- **Flyte plugins > Ray**
- **Flyte plugins > Spark**
=== PAGE: https://www.union.ai/docs/v2/byoc/integrations/flyte-plugins/dask ===
# Dask
Flyte can execute Dask jobs natively on a Kubernetes cluster and manages the
cluster's lifecycle: spin-up and tear-down. It leverages
the open-source Dask Kubernetes Operator and can be enabled without signing up for
any service. This is like running a transient Dask cluster: a cluster
spun up for a specific Dask job and torn down after completion.
## Install the plugin
To install the Dask plugin, run the following command:
```shell
$ pip install --pre flyteplugins-dask
```
The following example shows how to configure Dask in a `TaskEnvironment`. Flyte automatically provisions a Dask cluster for each task using this configuration:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
#    "flyteplugins-dask",
#    "distributed"
# ]
# main = "hello_dask_nested"
# params = ""
# ///

import asyncio
import typing

from distributed import Client
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

import flyte.remote
import flyte.storage
from flyte import Resources

image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("flyteplugins-dask")

dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

task_env = flyte.TaskEnvironment(
    name="hello_dask", resources=Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)

dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
    resources=Resources(cpu="1", memory="1Gi"),
    depends_on=[task_env],
)


@task_env.task()
async def hello_dask():
    await asyncio.sleep(5)
    print("Hello from the Dask task!")


@dask_env.task
async def hello_dask_nested(n: int = 3) -> typing.List[int]:
    print("running dask task")
    t = asyncio.create_task(hello_dask())
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))
    res = client.gather(futures)
    await t
    return res


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_dask_nested)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/dask/dask_example.py)
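The body of `hello_dask_nested` follows a map-then-gather pattern: `client.map` schedules one task per input and returns futures, and `client.gather` collects their results in input order. To sanity-check that logic without a Dask cluster, a rough standard-library stand-in using `concurrent.futures` is sketched below. It only illustrates the pattern; `map_then_gather` and `increment` are hypothetical helpers and do not use Dask.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def map_then_gather(fn: Callable[[T], R], items: Iterable[T], max_workers: int = 4) -> list[R]:
    """Submit one future per item (like client.map), then collect results in order (like client.gather)."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(fn, item) for item in items]
        return [f.result() for f in futures]


def increment(x: int) -> int:
    # Same function the Dask example maps: lambda x: x + 1
    return x + 1
```

`map_then_gather(increment, range(3))` returns `[1, 2, 3]`, the same list the Dask task returns for its default `n=3`.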
=== PAGE: https://www.union.ai/docs/v2/byoc/integrations/flyte-plugins/openai ===
# OpenAI
Flyte can integrate with OpenAI SDKs in your Flyte workflows.
The `flyteplugins-openai` plugin provides drop-in replacements for OpenAI SDKs like `openai-agents`, so that
you can build LLM-augmented workflows and agentic applications on Flyte.
## Install the plugin
To install the OpenAI plugin, run the following command:
```bash
pip install --pre flyteplugins-openai
```
## Subpages
- **Flyte plugins > OpenAI > Agent tools**
=== PAGE: https://www.union.ai/docs/v2/byoc/integrations/flyte-plugins/openai/agent_tools ===
# Agent tools
In this example, we will use the `openai-agents` library to create a simple agent that can use tools to perform tasks.
This example is based on the [basic tools example](https://github.com/openai/openai-agents-python/blob/main/examples/basic/tools.py) from the `openai-agents-python` repo.
First, create an OpenAI API key, which you can get from the [OpenAI website](https://platform.openai.com/account/api-keys).
Then, create a secret on your Flyte cluster with:
```
flyte create secret openai_api_key --value <your-openai-api-key>
```
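Then, we'll write the script, using inline `uv` script metadata (the `# /// script` header) to specify its dependencies. The complete script looks like this: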
````python
"""OpenAI Agents with Flyte, basic tool example.

Usage:

Create secret:

```
flyte create secret openai_api_key
uv run agents_tools.py
```
"""

# {{docs-fragment uv-script}}
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
#    "flyteplugins-openai>=2.0.0b7",
#    "openai-agents>=0.2.4",
#    "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///
# {{/docs-fragment uv-script}}

# {{docs-fragment imports-task-env}}
from agents import Agent, Runner
from pydantic import BaseModel

import flyte
from flyteplugins.openai.agents import function_tool

env = flyte.TaskEnvironment(
    name="openai_agents_tools",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
    secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
# {{/docs-fragment imports-task-env}}


# {{docs-fragment tools}}
class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str


@function_tool
@env.task
async def get_weather(city: str) -> Weather:
    """Get the weather for a given city."""
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
# {{/docs-fragment tools}}


# {{docs-fragment agent}}
agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)


@env.task
async def main() -> str:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
    return result.final_output
# {{/docs-fragment agent}}


# {{docs-fragment main}}
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
# {{/docs-fragment main}}
````
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py)
Next, we'll import the libraries and create a `TaskEnvironment`, which we need to run the example:
```python
from agents import Agent, Runner
from pydantic import BaseModel

import flyte
from flyteplugins.openai.agents import function_tool

env = flyte.TaskEnvironment(
    name="openai_agents_tools",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
    secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
```
## Define the tools
We'll define a tool that can get weather information for a
given city. In this case, we'll use a toy function that returns a hard-coded `Weather` object.
```python
class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str


@function_tool
@env.task
async def get_weather(city: str) -> Weather:
    """Get the weather for a given city."""
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
```
In this code snippet, the `@function_tool` decorator is imported from `flyteplugins.openai.agents`, which is a drop-in replacement for the `@function_tool` decorator from the `openai-agents` library.
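A tool-calling agent needs a machine-readable description of each tool's parameters, and `openai-agents` derives one from the decorated function's signature and docstring. The standard-library sketch below only illustrates that idea for simple scalar parameters; `tool_schema` and its type mapping are hypothetical and far simpler than the schema the library actually generates, and the `get_weather` here is a simplified variant that returns a string.

```python
import inspect
from typing import Any, get_type_hints

# Hypothetical mapping covering a few scalar types only.
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def tool_schema(fn) -> dict[str, Any]:
    """Build a minimal JSON-schema-like description of a function's parameters."""
    hints = get_type_hints(fn)
    properties: dict[str, Any] = {}
    required: list[str] = []
    for name, param in inspect.signature(fn).parameters.items():
        # Unannotated or unknown types fall back to "string" in this sketch.
        properties[name] = {"type": _JSON_TYPES.get(hints.get(name), "string")}
        if param.default is inspect.Parameter.empty:
            required.append(name)
    return {
        "name": fn.__name__,
        "description": inspect.getdoc(fn) or "",
        "parameters": {"type": "object", "properties": properties, "required": required},
    }


async def get_weather(city: str) -> str:
    """Get the weather for a given city."""
    return f"Sunny in {city}"
```

This is also why the tool's docstring and type hints matter: they are what the model sees when deciding whether and how to call the tool.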
## Define the agent
Then, we'll define the agent, which calls the tool:
```python
agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)


@env.task
async def main() -> str:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
    return result.final_output
```
## Run the agent
Finally, we'll run the agent. First, create a `config.yaml` file, which `flyte.init_from_config()` uses to connect to
the Flyte cluster:
```bash
flyte create config \
--output ~/.flyte/config.yaml \
--endpoint demo.hosted.unionai.cloud/ \
--project flytesnacks \
--domain development \
--builder remote
```
Then run the script, for example with `uv run agents_tools.py` as shown in its docstring. Its `__main__` block connects using that config and launches the `main` task:
```python
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
```
## Conclusion
In this example, we've seen how to use the `openai-agents` library to create a simple agent that can use tools to perform tasks.
The full code is available [here](https://github.com/unionai/unionai-examples/tree/main/v2/integrations/flyte-plugins/openai/openai).
=== PAGE: https://www.union.ai/docs/v2/byoc/integrations/flyte-plugins/pytorch ===
# Pytorch
Flyte can execute distributed PyTorch jobs (similar to running a `torchrun` script) natively on a Kubernetes cluster
and manages the cluster's lifecycle: spin-up and tear-down.
It leverages the open-source Kubeflow Training Operator.
This is like running a transient PyTorch cluster: a cluster
spun up for a specific PyTorch job and torn down after completion.
To install the plugin, run the following command:
```shell
$ pip install --pre flyteplugins-pytorch
```
The following example shows how to configure PyTorch in a `TaskEnvironment`. Flyte automatically provisions a PyTorch cluster for each task using this configuration:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
#    "flyteplugins-pytorch",
#    "torch"
# ]
# main = "torch_distributed_train"
# params = "3"
# ///

import typing

import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from flyteplugins.pytorch.task import Elastic
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

import flyte

image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch", pre=True)

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nproc_per_node=1,
        # if you want to do local testing set nnodes=1
        nnodes=2,
    ),
    image=image,
)


class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataLoader:
    """
    Prepare a DataLoader with a DistributedSampler so each rank
    gets a shard of the dataset.
    """
    # Dummy dataset
    x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])
    dataset = TensorDataset(x_train, y_train)

    # Distributed-aware sampler
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


def train_loop(epochs: int = 3) -> float:
    """
    A simple training loop for linear regression.
    """
    torch.distributed.init_process_group("gloo")
    model = DDP(LinearRegressionModel())

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    dataloader = prepare_dataloader(
        rank=rank,
        world_size=world_size,
        batch_size=64,
    )

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    final_loss = 0.0

    for _ in range(epochs):
        for x, y in dataloader:
            outputs = model(x)
            loss = criterion(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            final_loss = loss.item()
        if torch.distributed.get_rank() == 0:
            print(f"Loss: {final_loss}")

    return final_loss


@torch_env.task
def torch_distributed_train(epochs: int) -> typing.Optional[float]:
    """
    A task that runs a simple distributed training job using PyTorch's
    DistributedDataParallel (DDP).
    """
    print("starting launcher")
    loss = train_loop(epochs=epochs)
    print("Training complete")
    return loss


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(torch_distributed_train, epochs=3)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/pytorch/pytorch_example.py)
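With `nnodes=2` and `nproc_per_node=1`, the job runs two worker processes, so `DistributedSampler` splits the four-sample dummy dataset between two ranks. The following pure-Python sketch models that split for the simpler `shuffle=False` case (the example itself uses `shuffle=True`, which permutes the indices before splitting). It is a simplified illustration, not PyTorch's implementation, and `world_size` and `shard_indices` are hypothetical helper names.

```python
import math


def world_size(nnodes: int, nproc_per_node: int) -> int:
    # One worker process per (node, local process) pair.
    return nnodes * nproc_per_node


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list:
    """Simplified model of DistributedSampler(shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every rank gets the same count.
    padding = total_size - len(indices)
    if padding:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    ws = world_size(nnodes=2, nproc_per_node=1)
    for rank in range(ws):
        print(rank, shard_indices(4, ws, rank))
```

Under this model each rank holds two samples, so with `batch_size=64` every rank processes a single batch per epoch.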
=== PAGE: https://www.union.ai/docs/v2/byoc/integrations/flyte-plugins/ray ===
# Ray
Flyte can execute Ray jobs natively on a Kubernetes cluster,
managing a virtual cluster's lifecycle from spin-up to tear-down.
It leverages the open-source [KubeRay](https://github.com/ray-project/kuberay) operator and can be
enabled without signing up for any service. This is like running a transient Ray
cluster: a cluster spun up for a specific Ray job and torn down after
completion.
## Install the plugin
To install the Ray plugin, run the following command:
```shell
$ pip install --pre flyteplugins-ray
```
The following example shows how to configure Ray in a `TaskEnvironment`. Flyte automatically provisions a Ray cluster for each task using this configuration:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-ray",
# "ray[default]==2.46.0"
# ]
# main = "hello_ray_nested"
# params = "3"
# ///
import asyncio
import typing

import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

import flyte.remote
import flyte.storage


@ray.remote
def f(x):
    return x * x


ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)

image = (
    flyte.Image.from_debian_base(name="ray")
    .with_apt_packages("wget")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray", "pip", "mypy")
)

task_env = flyte.TaskEnvironment(
    name="hello_ray", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)

ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(3, 4), memory=("3000Mi", "5000Mi")),
    depends_on=[task_env],
)


@task_env.task()
async def hello_ray():
    await asyncio.sleep(20)
    print("Hello from the Ray task!")


@ray_env.task
async def hello_ray_nested(n: int = 3) -> typing.List[int]:
    print("running ray task")
    t = asyncio.create_task(hello_ray())
    futures = [f.remote(i) for i in range(n)]
    res = ray.get(futures)
    await t
    return res


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_ray_nested)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/ray/ray_example.py)
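In `hello_ray_nested`, each `f.remote(i)` call returns an object reference immediately, and `ray.get(futures)` blocks until every result is available, returning them in submission order; with the default `n=3` the task returns `[0, 1, 4]`. As a rough analogy only (no Ray involved), the same fan-out/gather shape can be written with the standard library's `concurrent.futures`; `fan_out` is a hypothetical name.

```python
from concurrent.futures import ThreadPoolExecutor


def f(x: int) -> int:
    return x * x


def fan_out(n: int) -> list:
    with ThreadPoolExecutor() as pool:
        # Like f.remote(i): submit every call up front; each returns a future immediately.
        futures = [pool.submit(f, i) for i in range(n)]
        # Like ray.get(futures): wait for all results, in submission order.
        return [fut.result() for fut in futures]


print(fan_out(3))  # [0, 1, 4]
```

Ray differs in that the work is scheduled across the cluster's worker processes rather than local threads.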
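The next example shows how Flyte can create an ephemeral Ray cluster and then run a separate task that connects to that existing cluster. Note the longer `ttl_seconds_after_finished=3600`, which presumably keeps the cluster running long enough after `create_ray_cluster` completes for `hello_ray` to connect to it: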
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-ray",
# "ray[default]==2.46.0"
# ]
# main = "create_ray_cluster"
# params = ""
# ///
import os
import typing

import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

import flyte.storage


@ray.remote
def f(x):
    return x * x


ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=3600,
)

image = (
    flyte.Image.from_debian_base(name="ray")
    .with_apt_packages("wget")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)

task_env = flyte.TaskEnvironment(
    name="ray_client", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)

ray_env = flyte.TaskEnvironment(
    name="ray_cluster",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(2, 4), memory=("2000Mi", "4000Mi")),
    depends_on=[task_env],
)


@task_env.task()
async def hello_ray(cluster_ip: str) -> typing.List[int]:
    """
    Run a simple Ray task that connects to an existing Ray cluster.
    """
    ray.init(address=f"ray://{cluster_ip}:10001")
    futures = [f.remote(i) for i in range(5)]
    res = ray.get(futures)
    return res


@ray_env.task
async def create_ray_cluster() -> str:
    """
    Create a Ray cluster and return the head node IP address.
    """
    print("creating ray cluster")
    cluster_ip = os.getenv("MY_POD_IP")
    if cluster_ip is None:
        raise ValueError("MY_POD_IP environment variable is not set")
    return f"{cluster_ip}"


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(create_ray_cluster)
    run.wait()
    print("run url:", run.url)
    print("cluster created, running ray task")
    print("ray address:", run.outputs()[0])
    run = flyte.run(hello_ray, cluster_ip=run.outputs()[0])
    print("run url:", run.url)
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/ray/ray_existing_example.py)
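The hand-off between the two runs is just a string: `create_ray_cluster` returns the head pod's IP (read from `MY_POD_IP`), and `hello_ray` turns it into a Ray Client address on port 10001. The following minimal sketch adds some input validation to that step, assuming an IPv4 pod address; `ray_client_address` is a hypothetical helper, not part of the plugin.

```python
import ipaddress
import os
from typing import Optional

RAY_CLIENT_PORT = 10001  # port used by the example's ray.init() call


def ray_client_address(pod_ip: Optional[str], port: int = RAY_CLIENT_PORT) -> str:
    """Build a Ray Client address from the head pod's IPv4 address."""
    if pod_ip is None:
        raise ValueError("MY_POD_IP environment variable is not set")
    # IPv4Address raises AddressValueError (a ValueError subclass) on bad input,
    # so a malformed value fails here rather than inside ray.init().
    ip = ipaddress.IPv4Address(pod_ip)
    return f"ray://{ip}:{port}"


if __name__ == "__main__":
    # Falls back to a placeholder address when MY_POD_IP is unset locally.
    print(ray_client_address(os.getenv("MY_POD_IP", "10.0.0.12")))
```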
=== PAGE: https://www.union.ai/docs/v2/byoc/integrations/flyte-plugins/spark ===
# Spark
Flyte can execute Spark jobs natively on a Kubernetes cluster,
managing a virtual cluster's lifecycle from spin-up to tear-down. It leverages
the open-source Spark on K8s Operator and can be enabled without signing up for
any service. This is like running a transient Spark cluster: a cluster
spun up for a specific Spark job and torn down after completion.
To install the plugin, run the following command:
```bash
pip install --pre flyteplugins-spark
```
The following example shows how to configure Spark in a `TaskEnvironment`. Flyte automatically provisions a Spark cluster for each task using this configuration:
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-spark"
# ]
# main = "hello_spark_nested"
# params = "3"
# ///
import random
from copy import deepcopy
from operator import add

from flyteplugins.spark.task import Spark

import flyte.remote
from flyte._context import internal_ctx

image = (
    flyte.Image.from_base("apache/spark-py:v3.4.0")
    .clone(name="spark", python_version=(3, 10), registry="ghcr.io/flyteorg")
    .with_pip_packages("flyteplugins-spark", pre=True)
)

task_env = flyte.TaskEnvironment(
    name="get_pi", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
        "spark.kubernetes.file.upload.path": "/opt/spark/work-dir",
        "spark.jars": "https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar,https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.2.2/hadoop-aws-3.2.2.jar,https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("3000Mi", "5000Mi")),
    plugin_config=spark_conf,
    image=image,
    depends_on=[task_env],
)


def f(_):
    # Sample a random point in [-1, 1] x [-1, 1]; return 1 if it lies inside the unit circle.
    x = random.random() * 2 - 1
    y = random.random() * 2 - 1
    return 1 if x**2 + y**2 <= 1 else 0


@task_env.task
async def get_pi(count: int, partitions: int) -> float:
    # The hit ratio approximates pi / 4. Dividing by `partitions` works here because
    # hello_spark_nested draws exactly n = 1 * partitions samples.
    return 4.0 * count / partitions


@spark_env.task
async def hello_spark_nested(partitions: int = 3) -> float:
    n = 1 * partitions
    # Read the SparkSession from the task context (uses an internal API).
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
    return await get_pi(count, partitions)


@task_env.task
async def spark_overrider(executor_instances: int = 3, partitions: int = 4) -> float:
    # Copy the config so the module-level spark_conf is not mutated.
    updated_spark_conf = deepcopy(spark_conf)
    updated_spark_conf.spark_conf["spark.executor.instances"] = str(executor_instances)
    return await hello_spark_nested.override(plugin_config=updated_spark_conf)(partitions=partitions)


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_spark_nested)
    print(r.name)
    print(r.url)
    r.wait()
```
(Source code for the above example: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/spark/spark_example.py)
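The example estimates π by Monte Carlo sampling: `f` draws a random point in the square [-1, 1] × [-1, 1], the fraction that lands inside the unit circle approximates π/4, and `get_pi` scales it by 4. Because the example only draws `partitions` samples, its result is a very rough estimate. The sketch below runs the same estimator locally without Spark (`estimate_pi` is a hypothetical name) so you can see the estimate tighten as the sample count grows; `map` and `reduce(add)` stand in for Spark's `map(f).reduce(add)`.

```python
import random
from functools import reduce
from operator import add


def f(_):
    # Same sampler as the Spark example: 1 if a random point in the square
    # [-1, 1] x [-1, 1] lands inside the unit circle, else 0.
    x = random.random() * 2 - 1
    y = random.random() * 2 - 1
    return 1 if x**2 + y**2 <= 1 else 0


def estimate_pi(n: int, seed: int = 0) -> float:
    random.seed(seed)  # fixed seed for a repeatable result
    # Local stand-in for parallelize(range(1, n + 1)).map(f).reduce(add).
    count = reduce(add, map(f, range(1, n + 1)))
    # Hit ratio ~ circle area / square area = pi / 4.
    return 4.0 * count / n


for n in (3, 1_000, 200_000):
    print(n, estimate_pi(n))
```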
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference ===
# Reference
This section provides the reference material for the Flyte SDK and CLI.
To get started, add `flyte` to your project:
```shell
$ uv pip install --no-cache --prerelease=allow --upgrade flyte
```
This will install both the Flyte SDK and CLI.
### **Flyte SDK**
The Flyte SDK provides the core Python API for building workflows and apps on your Union instance.
### **Flyte CLI**
The Flyte CLI is the command-line interface for interacting with your Union instance.
## Subpages
- **Flyte CLI**
- **LLM context document**
- **Flyte SDK**
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-cli ===
# Flyte CLI
This is the command line interface for Flyte.
| Object | Action |
| ------ | ------ |
| `run` | `flyte abort run`, `flyte get run` |
| `api-key` | `flyte create api-key`, `flyte delete api-key`, `flyte get api-key` |
| `config` | `flyte create config`, `flyte get config` |
| `secret` | `flyte create secret`, `flyte delete secret`, `flyte get secret` |
| `trigger` | `flyte create trigger`, `flyte delete trigger`, `flyte get trigger`, `flyte update trigger` |
| `docs` | `flyte gen docs` |
| `action` | `flyte get action` |
| `app` | `flyte get app`, `flyte update app` |
| `io` | `flyte get io` |
| `logs` | `flyte get logs` |
| `project` | `flyte get project` |
| `task` | `flyte get task` |
| `deployed-task` | `flyte run deployed-task` |
Some commands, such as the `api-key` commands, are provided by plugins; see the command documentation for installation instructions.
| Action | On |
| ------ | -- |
| `abort` | `flyte abort run` |
| `build` | - |
| `create` | `flyte create api-key`, `flyte create config`, `flyte create secret`, `flyte create trigger` |
| `delete` | `flyte delete api-key`, `flyte delete secret`, `flyte delete trigger` |
| `deploy` | - |
| `gen` | `flyte gen docs` |
| `get` | `flyte get action`, `flyte get api-key`, `flyte get app`, `flyte get config`, `flyte get io`, `flyte get logs`, `flyte get project`, `flyte get run`, `flyte get secret`, `flyte get task`, `flyte get trigger` |
| `run` | `flyte run deployed-task` |
| `serve` | - |
| `update` | `flyte update app`, `flyte update trigger` |
| `whoami` | - |
## flyte
**`flyte [OPTIONS] COMMAND [ARGS]...`**
The Flyte CLI is the command line interface for working with the Flyte SDK and backend.
It follows a simple verb/noun structure,
where the top-level commands are verbs that describe the action to be taken,
and the subcommands are nouns that describe the object of the action.
The root command can be used to configure the CLI for persistent settings,
such as the endpoint, organization, and verbosity level.
Set endpoint and organization:
```bash
$ flyte --endpoint --org get project
```
Increase the verbosity level (useful for debugging, as it shows more logs
and exception traces):
```bash
$ flyte -vvv get logs
```
Override the default config file:
```bash
$ flyte --config /path/to/config.yaml run ...
```
* [Documentation](https://www.union.ai/docs/flyte/user-guide/)
* [GitHub](https://github.com/flyteorg/flyte): Please leave a star if you like Flyte!
* [Slack](https://slack.flyte.org): Join the community and ask questions.
* [Issues](https://github.com/flyteorg/flyte/issues)
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--version` | `boolean` | `False` | Show the version and exit. |
| `--endpoint` | `text` | | The endpoint to connect to. This overrides any configuration file and uses `pkce` to connect. |
| `--insecure` | `boolean` | | Use an insecure connection to the endpoint. If not specified, the CLI uses TLS. |
| `--auth-type` | `choice` | | Authentication type to use for the Flyte backend. Defaults to 'pkce'. |
| `-v`, `--verbose` | `integer` | `0` | Show verbose messages and exception traces. Repeat the flag to increase verbosity (e.g., `-vvv`). |
| `--org` | `text` | | The organization to which the command applies. |
| `-c`, `--config` | `path` | | Path to the configuration file to use. If not specified, the default configuration file is used. |
| `--output-format`, `-of` | `choice` | `table` | Output format for commands that support it. Defaults to 'table'. |
| `--log-format` | `choice` | `console` | Log format. Defaults to 'console', which is human-readable; 'json' is intended for machine parsing. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte abort
**`flyte abort COMMAND [ARGS]...`**
Abort an ongoing process.
#### flyte abort run
**`flyte abort run [OPTIONS] RUN_NAME`**
Abort a run.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte build
**`flyte build [OPTIONS] COMMAND [ARGS]...`**
Build the environments defined in a Python file or directory. This builds the images associated with the
environments.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--noop` | `boolean` | | Dummy parameter, placeholder for future use. Does not affect the build process. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte create
**`flyte create COMMAND [ARGS]...`**
Create resources in a Flyte deployment.
#### flyte create api-key
> **Note:** This command is provided by the `flyteplugins.union` plugin. See the plugin documentation for installation instructions.
**`flyte create api-key [OPTIONS]`**
Create an API key for headless authentication.
This creates OAuth application credentials that can be used to authenticate
with Union without interactive login. Set the generated API key
as the `FLYTE_API_KEY` environment variable. OAuth applications should not be
confused with Union Apps, which are an entirely different construct.
Examples:
```bash
# Create an API key named "ci-pipeline"
$ flyte create api-key --name ci-pipeline
# The output will include an export command like:
# export FLYTE_API_KEY=""
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--name` | `text` | | Name for the API key. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte create config
**`flyte create config [OPTIONS]`**
Creates a configuration file for the Flyte CLI.
If the `--output` option is not specified, the file is written to `.flyte/config.yaml` in the current directory.
If the file already exists, the command raises an error unless the `--force` option is used.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--endpoint` | `text` | | Endpoint of the Flyte backend. |
| `--insecure` | `boolean` | `False` | Use an insecure connection to the Flyte backend. |
| `--org` | `text` | | Organization to use. This will override the organization in the configuration file. |
| `-o`, `--output` | `path` | `.flyte/config.yaml` | Path where the configuration will be saved. Defaults to `.flyte/config.yaml` in the current directory. |
| `--force` | `boolean` | `False` | Force overwrite of the configuration file if it already exists. |
| `--image-builder`, `--builder` | `choice` | `local` | Image builder to use for building images. Defaults to 'local'. |
| `--auth-type` | `choice` | | Authentication type to use for the Flyte backend. Defaults to 'pkce'. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte create secret
**`flyte create secret [OPTIONS] NAME`**
Create a new secret. The name of the secret is required. For example:
```bash
$ flyte create secret my_secret --value my_value
```
If you don't provide a `--value` flag, you will be prompted to enter the
secret value in the terminal.
```bash
$ flyte create secret my_secret
Enter secret value:
```
If `--from-file` is specified, the value will be read from the file instead of being provided directly:
```bash
$ flyte create secret my_secret --from-file /path/to/secret_file
```
The `--type` option can be used to create specific types of secrets.
Either `regular` or `image_pull` can be specified.
Secrets intended to access container images should be specified as `image_pull`.
Other secrets should be specified as `regular`.
If no type is specified, `regular` is assumed.
For image pull secrets, you have several options:
1. Interactive mode (prompts for registry, username, password):
```bash
$ flyte create secret my_secret --type image_pull
```
2. With explicit credentials:
```bash
$ flyte create secret my_secret --type image_pull --registry ghcr.io --username myuser
```
3. From your existing Docker configuration, if you have previously run `docker login` and want to reuse
those credentials. Because you may be logged in to multiple registries, you can specify which registries
to include. If no registries are specified, all registries are added.
```bash
$ flyte create secret my_secret --type image_pull --from-docker-config --registries ghcr.io,docker.io
```
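Docker stores the credentials written by `docker login` in `~/.docker/config.json`, under an `auths` map keyed by registry host. The following sketch models the `--registries` filtering described above (keep only the listed hosts, or all of them when none are given). It is an assumption about the behavior, not the CLI's code, and `select_registry_auths` is a hypothetical name. In practice, Docker Hub entries are often keyed by a full URL such as `https://index.docker.io/v1/` rather than `docker.io`, and credentials held in a credential store (`credsStore`) do not appear in `auths` at all, so a real implementation has more cases to handle.

```python
import json
from typing import List, Optional


def select_registry_auths(docker_config: dict, registries: Optional[List[str]] = None) -> dict:
    """Keep only the `auths` entries for the requested registries (all if none given)."""
    auths = docker_config.get("auths", {})
    if not registries:
        return dict(auths)
    wanted = set(registries)
    return {host: entry for host, entry in auths.items() if host in wanted}


if __name__ == "__main__":
    # Shape of ~/.docker/config.json; "auth" is base64("user:password").
    config = json.loads(
        '{"auths": {"ghcr.io": {"auth": "dXNlcjpwYXNz"}, "quay.io": {"auth": "eDp5"}}}'
    )
    # Mirrors: --registries ghcr.io,docker.io
    print(sorted(select_registry_auths(config, "ghcr.io,docker.io".split(","))))
```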
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--value` | `text` | | Secret value. Mutually exclusive with `--from-file`, `--from-docker-config`, and `--registry`. |
| `--from-file` | `path` | | Path to the file with the binary secret. Mutually exclusive with `--value`, `--from-docker-config`, and `--registry`. |
| `--type` | `choice` | `regular` | Type of the secret. |
| `--from-docker-config` | `boolean` | `False` | Create an image pull secret from a Docker config file (only for `--type image_pull`). Mutually exclusive with `--value`, `--from-file`, `--registry`, `--username`, and `--password`. |
| `--docker-config-path` | `path` | | Path to the Docker config file (defaults to `~/.docker/config.json` or `$DOCKER_CONFIG`). |
| `--registries` | `text` | | Comma-separated list of registries to include (only with `--from-docker-config`). |
| `--registry` | `text` | | Registry hostname (e.g., ghcr.io, docker.io) for explicit credentials (only for `--type image_pull`). Mutually exclusive with `--value`, `--from-file`, and `--from-docker-config`. |
| `--username` | `text` | | Username for the registry (only with `--registry`). |
| `--password` | `text` | | Password for the registry (only with `--registry`). If not provided, you will be prompted. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte create trigger
**`flyte create trigger [OPTIONS] TASK_NAME NAME`**
Create a new trigger for a task. The task name and trigger name are required.
Example:
```bash
$ flyte create trigger my_task my_trigger --schedule "0 0 * * *"
```
This will create a trigger that runs every day at midnight.
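A cron schedule has five fields: minute, hour, day of month, month, and day of week. `0 0 * * *` therefore fires at minute 0 of hour 0 every day, and `* * * * *` (the documented default of every minute) fires each minute. The toy matcher below makes the field order concrete. `cron_matches` is a hypothetical helper that handles only `*` and single numbers, and it ignores lists, ranges, steps, and cron's OR rule for the two day fields. Which time zone the trigger's schedule is evaluated in is not stated here.

```python
from datetime import datetime


def cron_matches(expr: str, when: datetime) -> bool:
    """Match a 5-field cron expression that uses only '*' or single integers."""
    fields = expr.split()
    # Cron counts weekdays from Sunday=0; datetime.weekday() counts from Monday=0.
    actual = [when.minute, when.hour, when.day, when.month, (when.weekday() + 1) % 7]
    return all(field == "*" or int(field) == value for field, value in zip(fields, actual))


print(cron_matches("0 0 * * *", datetime(2025, 1, 15, 0, 0)))   # True
print(cron_matches("0 0 * * *", datetime(2025, 1, 15, 0, 1)))   # False
print(cron_matches("* * * * *", datetime(2025, 1, 15, 13, 37)))  # True
```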
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--schedule` | `text` | | Cron schedule for the trigger. Defaults to every minute. |
| `--description` | `text` | | Description of the trigger. |
| `--auto-activate` | `boolean` | `True` | Whether the trigger should be automatically activated. Defaults to True. |
| `--trigger-time-var` | `text` | `trigger_time` | Variable name for the trigger time in the task inputs. Defaults to 'trigger_time'. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte delete
**`flyte delete COMMAND [ARGS]...`**
Remove resources from a Flyte deployment.
#### flyte delete api-key
> **Note:** This command is provided by the `flyteplugins.union` plugin. See the plugin documentation for installation instructions.
**`flyte delete api-key [OPTIONS] CLIENT_ID`**
Delete an API key.
Examples:
```bash
# Delete an API key (with confirmation)
$ flyte delete api-key my-client-id
# Delete without confirmation
$ flyte delete api-key my-client-id --yes
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--yes` | `boolean` | `False` | Skip confirmation prompt |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte delete secret
**`flyte delete secret [OPTIONS] NAME`**
Delete a secret. The name of the secret is required.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte delete trigger
**`flyte delete trigger [OPTIONS] NAME TASK_NAME`**
Delete a trigger. Both the trigger name and the task name are required.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte deploy
**`flyte deploy [OPTIONS] COMMAND [ARGS]...`**
Deploy one or more environments from a Python file.
This command will create or update environments in the Flyte system, registering
all tasks and their dependencies.
Example usage:
```bash
flyte deploy hello.py my_env
```
Arguments to the deploy command are provided right after the `deploy` command and before the file name.
To deploy all environments in a file, use the `--all` flag:
```bash
flyte deploy --all hello.py
```
To recursively deploy all environments in a directory and its subdirectories, use the `--recursive` flag:
```bash
flyte deploy --recursive ./src
```
You can combine `--all` and `--recursive` to deploy everything:
```bash
flyte deploy --all --recursive ./src
```
You can provide image mappings with `--image` flag. This allows you to specify
the image URI for the task environment during CLI execution without changing
the code. Any images defined with `Image.from_ref_name("name")` will resolve to the
corresponding URIs you specify here.
```bash
flyte deploy --image my_image=ghcr.io/myorg/my-image:v1.0 hello.py my_env
```
If the image name is not provided, the image is treated as the default image and is
used whenever no image is specified in a `TaskEnvironment`:
```bash
flyte deploy --image ghcr.io/myorg/default-image:latest hello.py my_env
```
You can specify multiple image arguments:
```bash
flyte deploy --image ghcr.io/org/default:latest --image gpu=ghcr.io/org/gpu:v2.0 hello.py my_env
```
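Each `--image` value is either `name=uri`, which `Image.from_ref_name("name")` resolves to, or a bare URI that serves as the default image. The following sketch shows one way to interpret a list of such values by splitting on the first `=`. It illustrates the documented format and is not the CLI's parser; `parse_image_args` is hypothetical, and using `None` as the key for the default image is an arbitrary choice.

```python
from typing import Dict, List, Optional


def parse_image_args(values: List[str]) -> Dict[Optional[str], str]:
    """Map image names to URIs; the key None holds the default image."""
    mapping: Dict[Optional[str], str] = {}
    for value in values:
        # Split on the first "=" only; a bare value has no name.
        name, sep, uri = value.partition("=")
        if sep:
            mapping[name] = uri
        else:
            mapping[None] = name
    return mapping


print(parse_image_args(["ghcr.io/org/default:latest", "gpu=ghcr.io/org/gpu:v2.0"]))
```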
To deploy a specific version, use the `--version` flag:
```bash
flyte deploy --version v1.0.0 hello.py my_env
```
To preview what would be deployed without actually deploying, use the `--dry-run` flag:
```bash
flyte deploy --dry-run hello.py my_env
```
You can use the root-level `--config` option to point to a specific Flyte cluster. Because it is a root option, it goes before the `deploy` command:
```bash
flyte --config my-config.yaml deploy hello.py my_env
```
You can override the default configured project and domain:
```bash
flyte deploy --project my-project --domain development hello.py my_env
```
If loading some files fails during recursive deployment, you can use the `--ignore-load-errors` flag
to continue deploying the environments that loaded successfully:
```bash
flyte deploy --recursive --ignore-load-errors ./src
```
Other arguments to the deploy command are listed below.
To see the environments available in a file, use `--help` after the file name:
```bash
flyte deploy hello.py --help
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--version` | `text` | | Version of the environment to deploy. |
| `--dry-run`, `--dryrun` | `boolean` | `False` | Dry run. Do not actually call the backend service. |
| `--copy-style` | `choice` | `loaded_modules` | Copy style to use when running the task. |
| `--root-dir` | `text` | | Override the root source directory, helpful when working with monorepos. |
| `--recursive`, `-r` | `boolean` | `False` | Recursively deploy all environments in the current directory. |
| `--all` | `boolean` | `False` | Deploy all environments in the current directory, ignoring the file name. |
| `--ignore-load-errors`, `-i` | `boolean` | `False` | Ignore errors when loading environments, especially when using `--recursive` or `--all`. |
| `--no-sync-local-sys-paths` | `boolean` | `False` | Disable synchronization of local `sys.path` entries under the root directory to the remote container. |
| `--image` | `text` | | Image to be used in the run. Format: `imagename=imageuri`. Can be specified multiple times. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte gen
**`flyte gen COMMAND [ARGS]...`**
Generate documentation.
#### flyte gen docs
**`flyte gen docs [OPTIONS]`**
Generate documentation.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--type` | `text` | | Type of documentation (valid: markdown). |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte get
**`flyte get COMMAND [ARGS]...`**
Retrieve resources from a Flyte deployment.
You can get information about projects, runs, tasks, actions, secrets, logs and input/output values.
Each command supports optional parameters to filter or specify the resource you want to retrieve.
Using a `get` subcommand without any arguments will retrieve a list of available resources to get.
For example:
* `get project` (without specifying a project) lists all projects.
* `get project my_project` returns the details of the project named `my_project`.
In some cases, a partially specified command will act as a filter and return available further parameters.
For example:
* `get action my_run` will return all actions for the run named `my_run`.
* `get action my_run my_action` will return the details of the action named `my_action` for the run `my_run`.
#### flyte get action
**`flyte get action [OPTIONS] RUN_NAME [ACTION_NAME]`**
Get all actions for a run or details for a specific action.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get api-key
> **Note:** This command is provided by the `flyteplugins.union` plugin. See the plugin documentation for installation instructions.
**`flyte get api-key [OPTIONS] [CLIENT_ID]`**
Get or list API keys.
If `CLIENT_ID` is provided, gets a specific API key.
Otherwise, lists all API keys.
Examples:
```bash
# List all API keys
$ flyte get api-key
# List with a limit
$ flyte get api-key --limit 10
# Get a specific API key
$ flyte get api-key my-client-id
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Maximum number of keys to list |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get app
**`flyte get app [OPTIONS] [NAME]`**
Get a list of all apps, or details of a specific app by name.
Apps are long-running services deployed on the Flyte platform.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of apps to fetch when listing. |
| `--only-mine` | `boolean` | `False` | Show only apps created by the current user (you). |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get config
**`flyte get config`**
Shows the automatically detected configuration to connect with the remote backend.
The configuration will include the endpoint, organization, and other settings that are used by the CLI.
#### flyte get io
**`flyte get io [OPTIONS] RUN_NAME [ACTION_NAME]`**
Get the inputs and outputs of a run or action.
If only the run name is provided, it will show the inputs and outputs of the root action of that run.
If an action name is provided, it will show the inputs and outputs for that action.
If `--inputs-only` or `--outputs-only` is specified, it will only show the inputs or outputs respectively.
Examples:
```bash
$ flyte get io my_run
```
```bash
$ flyte get io my_run my_action
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--inputs-only`, `-i` | `boolean` | `False` | Show only inputs. |
| `--outputs-only`, `-o` | `boolean` | `False` | Show only outputs. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get logs
**`flyte get logs [OPTIONS] RUN_NAME [ACTION_NAME]`**
Stream logs for the provided run or action.
If only the run name is provided, only the logs for the parent action are streamed:
```bash
$ flyte get logs my_run
```
If you want to see the logs for a specific action, you can provide the action name as well:
```bash
$ flyte get logs my_run my_action
```
By default, logs are printed in raw format and scroll the terminal.
To show only the last `--lines` lines in an auto-scrolling box, use the `--pretty` flag:
```bash
$ flyte get logs my_run my_action --pretty --lines 50
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--lines`, `-l` | `integer` | `30` | Number of lines to show; only used with `--pretty`. |
| `--show-ts` | `boolean` | `False` | Show timestamps. |
| `--pretty` | `boolean` | `False` | Show logs in an auto-scrolling box limited to `--lines` lines. |
| `--attempt`, `-a` | `integer` | | Attempt number to show logs for. Defaults to the latest attempt. |
| `--filter-system` | `boolean` | `False` | Filter all system logs from the output. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
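With `--pretty`, only the most recent `--lines` lines stay on screen. That windowing idea can be illustrated with a bounded `collections.deque`. This is a sketch of the concept only, not the CLI's implementation, and `tail_lines` is a hypothetical helper.

```python
from collections import deque
from typing import Iterable, List


def tail_lines(stream: Iterable[str], lines: int = 30) -> List[str]:
    """Return only the last `lines` entries of a log stream.

    Hypothetical helper illustrating `--pretty --lines N`: a fixed-size
    window of recent lines is kept and older lines are discarded.
    """
    if lines <= 0:
        return []
    window = deque(maxlen=lines)  # appending beyond maxlen drops the oldest entry
    for line in stream:
        window.append(line)
    return list(window)


log = [f"step {i}" for i in range(100)]
print(tail_lines(log, lines=3))  # ['step 97', 'step 98', 'step 99']
```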
#### flyte get project
**`flyte get project [NAME]`**
Get a list of all projects, or details of a specific project by name.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get run
**`flyte get run [OPTIONS] [NAME]`**
Get a list of all runs, or details of a specific run by name.
The run details include information about the run and its status, but only the root action is shown.
To see the actions for a run, use `flyte get action`.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of runs to fetch when listing. |
| `--in-phase` | `choice` | | Filter runs by their status. |
| `--only-mine` | `boolean` | `False` | Show only runs created by the current user (you). |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get secret
**`flyte get secret [OPTIONS] [NAME]`**
Get a list of all secrets, or details of a specific secret by name.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get task
**`flyte get task [OPTIONS] [NAME] [VERSION]`**
Retrieve a list of all tasks, or details of a specific task by name and version.
Currently, both `name` and `version` are required to get a specific task.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of tasks to fetch. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte get trigger
**`flyte get trigger [OPTIONS] [TASK_NAME] [NAME]`**
Get a list of all triggers, or details of a specific trigger by name.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--limit` | `integer` | `100` | Limit the number of triggers to fetch. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte run
**`flyte run [OPTIONS] COMMAND [ARGS]...`**
Run a task from a Python file or a deployed task.
Example usage:
```bash
flyte run hello.py my_task --arg1 value1 --arg2 value2
```
Arguments to the run command are provided right after the `run` command and before the file name.
Arguments for the task itself are provided after the task name.
To run a task locally, use the `--local` flag. This will run the task in the local environment instead of the remote
Flyte environment:
```bash
flyte run --local hello.py my_task --arg1 value1 --arg2 value2
```
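The positional convention described above (options for `run` first, then the file and the task name, then the task's own arguments) can be pictured as a simple partition of the argument list. This is a minimal sketch that assumes the first token ending in `.py` is the file and the next token is the task name; `split_run_args` and `RunInvocation` are hypothetical names, not part of the Flyte CLI.

```python
from typing import List, NamedTuple


class RunInvocation(NamedTuple):
    run_options: List[str]  # options for `flyte run` itself
    file: str               # Python file containing the task
    task: str               # task name inside the file
    task_args: List[str]    # arguments forwarded to the task


def split_run_args(argv: List[str]) -> RunInvocation:
    """Hypothetical helper: partition `flyte run` arguments by position.

    Assumes the first token ending in `.py` is the file and the token
    right after it is the task name.
    """
    for i, token in enumerate(argv):
        if token.endswith(".py"):
            if i + 1 >= len(argv):
                raise ValueError("missing task name after file")
            return RunInvocation(argv[:i], token, argv[i + 1], argv[i + 2:])
    raise ValueError("no Python file found in arguments")


inv = split_run_args(
    ["--local", "hello.py", "my_task", "--arg1", "value1", "--arg2", "value2"]
)
print(inv.run_options, inv.file, inv.task, inv.task_args)
```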
You can provide image mappings with the `--image` flag. This allows you to specify
the image URI for the task environment during CLI execution without changing
the code. Any images defined with `Image.from_ref_name("name")` will resolve to the
corresponding URIs you specify here.
```bash
flyte run --image my_image=ghcr.io/myorg/my-image:v1.0 hello.py my_task
```
If the image name is omitted, the URI is treated as the default image and is
used when no image is specified in the `TaskEnvironment`:
```bash
flyte run --image ghcr.io/myorg/default-image:latest hello.py my_task
```
You can specify multiple image arguments:
```bash
flyte run --image ghcr.io/org/default:latest --image gpu=ghcr.io/org/gpu:v2.0 hello.py my_task
```
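The `--image` rule above (a `name=uri` pair maps a named image, a bare URI sets the default) can be sketched as a small parser. This is an illustration under that reading of the text, not the CLI's own code; `parse_image_mappings` and the use of `None` as the key for the default image are hypothetical choices.

```python
from typing import Dict, List, Optional


def parse_image_mappings(values: List[str]) -> Dict[Optional[str], str]:
    """Hypothetical helper: parse repeated `--image` values.

    `name=uri` maps a named image; a bare `uri` is stored under the key
    None, standing in for the default image.
    """
    mappings: Dict[Optional[str], str] = {}
    for value in values:
        name, sep, uri = value.partition("=")
        if sep:
            mappings[name] = uri      # named image, e.g. gpu=ghcr.io/...
        else:
            mappings[None] = value    # no name given: default image
    return mappings


m = parse_image_mappings(["ghcr.io/org/default:latest", "gpu=ghcr.io/org/gpu:v2.0"])
print(m)  # {None: 'ghcr.io/org/default:latest', 'gpu': 'ghcr.io/org/gpu:v2.0'}
```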
To run tasks that you've already deployed to Flyte, use the `deployed-task` subcommand:
```bash
flyte run deployed-task my_env.my_task --arg1 value1 --arg2 value2
```
To run a specific version of a deployed task, use the `env.task:version` syntax:
```bash
flyte run deployed-task my_env.my_task:xyz123 --arg1 value1 --arg2 value2
```
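The `env.task:version` reference can be split into its three parts as sketched below. This assumes the environment name contains no `.` (the split is on the first `.`) and that an omitted version simply means none was requested; `parse_task_ref` and `TaskRef` are hypothetical names used only for illustration.

```python
from typing import NamedTuple, Optional


class TaskRef(NamedTuple):
    env: str
    task: str
    version: Optional[str]  # None means no version was given


def parse_task_ref(ref: str) -> TaskRef:
    """Hypothetical helper: split `env.task[:version]` into its parts."""
    name, _, version = ref.partition(":")
    env, dot, task = name.partition(".")  # split on the first '.'
    if not dot or not env or not task:
        raise ValueError(f"expected env.task[:version], got {ref!r}")
    return TaskRef(env, task, version or None)


print(parse_task_ref("my_env.my_task"))
print(parse_task_ref("my_env.my_task:xyz123"))
```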
You can specify the `--config` flag to point to a specific Flyte cluster:
```bash
flyte run --config my-config.yaml deployed-task ...
```
You can override the default configured project and domain:
```bash
flyte run --project my-project --domain development hello.py my_task
```
You can discover what deployed tasks are available by running:
```bash
flyte run deployed-task
```
Other options for the `run` command are listed below.
Arguments for the task itself are provided after the task name; to list them, pass `--help` after the task name. For example:
```bash
flyte run hello.py my_task --help
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--local` | `boolean` | `False` | Run the task locally. |
| `--copy-style` | `choice` | `loaded_modules` | Copy style to use when running the task. |
| `--root-dir` | `text` | | Override the root source directory; helpful when working with monorepos. |
| `--raw-data-path` | `text` | | Override the output prefix used to store offloaded data types, for example `s3://bucket/`. |
| `--service-account` | `text` | | Kubernetes service account. If not provided, the configured default is used. |
| `--name` | `text` | | Name of the run. If not provided, a random name is generated. |
| `--follow`, `-f` | `boolean` | `False` | Wait and watch logs for the parent action. If not provided, the CLI exits after successfully launching a remote execution and prints a link to the UI. |
| `--image` | `text` | | Image to use in the run. Format: `imagename=imageuri`, or a bare image URI for the default image. Can be specified multiple times. |
| `--no-sync-local-sys-paths` | `boolean` | `False` | Disable synchronization of local sys.path entries under the root directory to the remote container. |
| `--help` | `boolean` | `False` | Show this message and exit. |
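The `--no-sync-local-sys-paths` description implies that, by default, `sys.path` entries located under the root directory are carried over to the remote container. A minimal sketch of selecting such entries is shown below, assuming "under the root" means the entry is the root itself or one of its subdirectories. The helper name is hypothetical, and this is not necessarily how the CLI decides which entries to sync.

```python
import sys
from pathlib import Path
from typing import Iterable, List


def local_paths_under_root(paths: Iterable[str], root_dir: str) -> List[str]:
    """Hypothetical helper: keep sys.path entries inside `root_dir`.

    Entries outside the root (for example site-packages or the stdlib)
    are dropped.
    """
    root = Path(root_dir).resolve()
    selected = []
    for entry in paths:
        if not entry:
            continue  # '' means the current directory; skipped here for simplicity
        p = Path(entry).resolve()
        if p == root or root in p.parents:
            selected.append(str(p))
    return selected


print(local_paths_under_root(sys.path, "."))
```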
#### flyte run deployed-task
**`flyte run deployed-task [OPTIONS] COMMAND [ARGS]...`**
Run a reference task from the Flyte backend.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte serve
**`flyte serve [OPTIONS] COMMAND [ARGS]...`**
Serve an app from a Python file using `flyte.serve()`.
This command allows you to serve apps defined with `flyte.app.AppEnvironment`
in your Python files. The serve command will deploy the app to the Flyte backend
and start it, making it accessible via a URL.
Example usage:
```bash
flyte serve examples/apps/basic_app.py app_env
```
Arguments to the serve command are provided right after the `serve` command and before the file name.
To follow the logs of the served app, use the `--follow` flag:
```bash
flyte serve --follow examples/apps/basic_app.py app_env
```
Note: Log streaming is not yet fully implemented and will be added in a future release.
You can provide image mappings with the `--image` flag. This allows you to specify
the image URI for the app environment during CLI execution without changing
the code. Any images defined with `Image.from_ref_name("name")` will resolve to the
corresponding URIs you specify here.
```bash
flyte serve --image my_image=ghcr.io/myorg/my-image:v1.0 examples/apps/basic_app.py app_env
```
If the image name is omitted, the URI is treated as the default image and is
used when no image is specified in the `AppEnvironment`:
```bash
flyte serve --image ghcr.io/myorg/default-image:latest examples/apps/basic_app.py app_env
```
You can specify multiple image arguments:
```bash
flyte serve --image ghcr.io/org/default:latest --image gpu=ghcr.io/org/gpu:v2.0 examples/apps/basic_app.py app_env
```
You can specify the `--config` flag to point to a specific Flyte cluster:
```bash
flyte serve --config my-config.yaml examples/apps/basic_app.py app_env
```
You can override the default configured project and domain:
```bash
flyte serve --project my-project --domain development examples/apps/basic_app.py app_env
```
Other arguments to the serve command are listed below.
Note: This pattern is primarily useful for serving apps defined in tasks.
Serving deployed apps is not currently supported through this CLI command.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--copy-style` | `choice` | `loaded_modules` | Copy style to use when serving the app. |
| `--root-dir` | `text` | | Override the root source directory; helpful when working with monorepos. |
| `--service-account` | `text` | | Kubernetes service account. If not provided, the configured default is used. |
| `--name` | `text` | | Name of the app deployment. If not provided, the app environment name is used. |
| `--follow`, `-f` | `boolean` | `False` | Wait and watch logs for the app. If not provided, the CLI exits after successfully deploying the app and prints a link to the UI. |
| `--image` | `text` | | Image to use for the app. Format: `imagename=imageuri`, or a bare image URI for the default image. Can be specified multiple times. |
| `--no-sync-local-sys-paths` | `boolean` | `False` | Disable synchronization of local sys.path entries under the root directory to the remote container. |
| `--env-var`, `-e` | `text` | | Environment variable to set in the app. Format: `KEY=VALUE`. Can be specified multiple times, for example `--env-var LOG_LEVEL=DEBUG --env-var DATABASE_URL=postgresql://...`. |
| `--help` | `boolean` | `False` | Show this message and exit. |
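Repeated `--env-var KEY=VALUE` options amount to building a mapping of environment variables. The sketch below shows one way to do that, splitting on the first `=` so values such as connection strings may contain `=` themselves. The helper name and the "last occurrence wins" rule for duplicate keys are assumptions of this sketch, not documented CLI behavior.

```python
from typing import Dict, List


def parse_env_vars(values: List[str]) -> Dict[str, str]:
    """Hypothetical helper: turn repeated KEY=VALUE strings into a dict.

    Splits on the first '=' only, so values may themselves contain '='.
    """
    env: Dict[str, str] = {}
    for item in values:
        key, sep, value = item.partition("=")
        if not sep or not key:
            raise ValueError(f"expected KEY=VALUE, got {item!r}")
        env[key] = value  # in this sketch, a later duplicate key overrides an earlier one
    return env


print(parse_env_vars(["LOG_LEVEL=DEBUG", "DATABASE_URL=postgresql://db?sslmode=require"]))
```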
### flyte update
**`flyte update COMMAND [ARGS]...`**
Update various Flyte entities.
#### flyte update app
**`flyte update app [OPTIONS] NAME`**
Update an app by starting or stopping it.
Example usage:
```bash
flyte update app <name> --activate | --deactivate [--wait] [--project <project>] [--domain <domain>]
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--activate` / `--deactivate` | `boolean` | | Activate or deactivate the app. |
| `--wait` | `boolean` | `False` | Wait for the app to reach the desired state. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
#### flyte update trigger
**`flyte update trigger [OPTIONS] NAME TASK_NAME`**
Update a trigger.
Example usage:
```bash
flyte update trigger <name> <task_name> --activate | --deactivate [--project <project>] [--domain <domain>]
```
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `--activate` / `--deactivate` | `boolean` | | Activate or deactivate the trigger. |
| `-p`, `--project` | `text` | | Project to which this command applies. |
| `-d`, `--domain` | `text` | | Domain to which this command applies. |
| `--help` | `boolean` | `False` | Show this message and exit. |
### flyte whoami
**`flyte whoami`**
Display the current user information.
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-context ===
# LLM context document
The following document provides LLM context for authoring and running Flyte/Union workflows.
It can serve as a reference for LLM-based AI assistants to understand how to properly write, configure, and execute Flyte/Union workflows.
* **Full documentation content**: The entire documentation (this site) for Union.ai version 2.0 in a single text file: [llms-full.txt](/_static/public/llms-full.txt)
You can add it to the context window of your LLM-based AI assistant to help it better understand Flyte/Union development.
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk ===
# Flyte SDK
These are the docs for Flyte SDK version 2.0.
Flyte is the core Python SDK for the Union and Flyte platforms.
## Subpages
- **Flyte SDK > Classes**
- **Flyte SDK > Packages**
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/classes ===
# Classes
| Class | Description |
|-|-|
| [`flyte.Cache`](../packages/flyte/cache) |Cache configuration for a task. |
| [`flyte.Cron`](../packages/flyte/cron) |This class defines a Cron automation that can be associated with a Trigger in Flyte. |
| [`flyte.Device`](../packages/flyte/device) |Represents a device type, its quantity and partition if applicable. |
| [`flyte.Environment`](../packages/flyte/environment) | |
| [`flyte.FixedRate`](../packages/flyte/fixedrate) |This class defines a FixedRate automation that can be associated with a Trigger in Flyte. |
| [`flyte.Image`](../packages/flyte/image) |This is a representation of Container Images, which can be used to create layered images programmatically. |
| [`flyte.PodTemplate`](../packages/flyte/podtemplate) |Custom PodTemplate specification for a Task. |
| [`flyte.Resources`](../packages/flyte/resources) |Resources such as CPU, Memory, and GPU that can be allocated to a task. |
| [`flyte.RetryStrategy`](../packages/flyte/retrystrategy) |Retry strategy for the task or task environment. |
| [`flyte.ReusePolicy`](../packages/flyte/reusepolicy) |ReusePolicy can be used to configure a task to reuse the environment. |
| [`flyte.Secret`](../packages/flyte/secret) |Secrets are used to inject sensitive information into tasks or image build context. |
| [`flyte.TaskEnvironment`](../packages/flyte/taskenvironment) |Environment class to define a new environment for a set of tasks. |
| [`flyte.Timeout`](../packages/flyte/timeout) |Timeout class to define a timeout for a task. |
| [`flyte.Trigger`](../packages/flyte/trigger) |This class defines specification of a Trigger, that can be associated with any Flyte V2 task. |
| [`flyte.app.AppEndpoint`](../packages/flyte.app/appendpoint) |Embed an upstream app's endpoint as an app input. |
| [`flyte.app.AppEnvironment`](../packages/flyte.app/appenvironment) | |
| [`flyte.app.Domain`](../packages/flyte.app/domain) |Subdomain to use for the domain. |
| [`flyte.app.Input`](../packages/flyte.app/input) |Input for application. |
| [`flyte.app.Link`](../packages/flyte.app/link) |Custom links to add to the app. |
| [`flyte.app.Port`](../packages/flyte.app/port) | |
| [`flyte.app.RunOutput`](../packages/flyte.app/runoutput) |Use a run's output for app inputs. |
| [`flyte.app.Scaling`](../packages/flyte.app/scaling) | |
| [`flyte.config.Config`](../packages/flyte.config/config) |This is the parent configuration object and holds all the underlying configuration object types. |
| [`flyte.errors.ActionNotFoundError`](../packages/flyte.errors/actionnotfounderror) |This error is raised when the user tries to access an action that does not exist. |
| [`flyte.errors.BaseRuntimeError`](../packages/flyte.errors/baseruntimeerror) |Base class for all Union runtime errors. |
| [`flyte.errors.CustomError`](../packages/flyte.errors/customerror) |This error is raised when the user raises a custom error. |
| [`flyte.errors.DeploymentError`](../packages/flyte.errors/deploymenterror) |This error is raised when the deployment of a task fails, or some preconditions for deployment are not met. |
| [`flyte.errors.ImageBuildError`](../packages/flyte.errors/imagebuilderror) |This error is raised when the image build fails. |
| [`flyte.errors.ImagePullBackOffError`](../packages/flyte.errors/imagepullbackofferror) |This error is raised when the image cannot be pulled. |
| [`flyte.errors.InitializationError`](../packages/flyte.errors/initializationerror) |This error is raised when the Union system is accessed before it has been initialized. |
| [`flyte.errors.InlineIOMaxBytesBreached`](../packages/flyte.errors/inlineiomaxbytesbreached) |This error is raised when the inline IO max bytes limit is breached. |
| [`flyte.errors.InvalidImageNameError`](../packages/flyte.errors/invalidimagenameerror) |This error is raised when the image name is invalid. |
| [`flyte.errors.LogsNotYetAvailableError`](../packages/flyte.errors/logsnotyetavailableerror) |This error is raised when the logs are not yet available for a task. |
| [`flyte.errors.ModuleLoadError`](../packages/flyte.errors/moduleloaderror) |This error is raised when the module cannot be loaded, either because it does not exist or because of a. |
| [`flyte.errors.NotInTaskContextError`](../packages/flyte.errors/notintaskcontexterror) |This error is raised when the user tries to access the task context outside of a task. |
| [`flyte.errors.OOMError`](../packages/flyte.errors/oomerror) |This error is raised when the underlying task execution fails because of an out-of-memory error. |
| [`flyte.errors.OnlyAsyncIOSupportedError`](../packages/flyte.errors/onlyasynciosupportederror) |This error is raised when the user tries to use sync IO in an async task. |
| [`flyte.errors.PrimaryContainerNotFoundError`](../packages/flyte.errors/primarycontainernotfounderror) |This error is raised when the primary container is not found. |
| [`flyte.errors.ReferenceTaskError`](../packages/flyte.errors/referencetaskerror) |This error is raised when the user tries to access a task that does not exist. |
| [`flyte.errors.RetriesExhaustedError`](../packages/flyte.errors/retriesexhaustederror) |This error is raised when the underlying task execution fails after all retries have been exhausted. |
| [`flyte.errors.RunAbortedError`](../packages/flyte.errors/runabortederror) |This error is raised when the run is aborted by the user. |
| [`flyte.errors.RuntimeDataValidationError`](../packages/flyte.errors/runtimedatavalidationerror) |This error is raised when the user tries to access a resource that does not exist or is invalid. |
| [`flyte.errors.RuntimeSystemError`](../packages/flyte.errors/runtimesystemerror) |This error is raised when the underlying task execution fails because of a system error. |
| [`flyte.errors.RuntimeUnknownError`](../packages/flyte.errors/runtimeunknownerror) |This error is raised when the underlying task execution fails because of an unknown error. |
| [`flyte.errors.RuntimeUserError`](../packages/flyte.errors/runtimeusererror) |This error is raised when the underlying task execution fails because of an error in the user's code. |
| [`flyte.errors.SlowDownError`](../packages/flyte.errors/slowdownerror) |This error is raised when the user tries to access a resource that does not exist or is invalid. |
| [`flyte.errors.TaskInterruptedError`](../packages/flyte.errors/taskinterruptederror) |This error is raised when the underlying task execution is interrupted. |
| [`flyte.errors.TaskTimeoutError`](../packages/flyte.errors/tasktimeouterror) |This error is raised when the underlying task execution runs for longer than the specified timeout. |
| [`flyte.errors.UnionRpcError`](../packages/flyte.errors/unionrpcerror) |This error is raised when communication with the Union server fails. |
| [`flyte.extend.AsyncFunctionTaskTemplate`](../packages/flyte.extend/asyncfunctiontasktemplate) |A task template that wraps an asynchronous function. |
| [`flyte.extend.ImageBuildEngine`](../packages/flyte.extend/imagebuildengine) |ImageBuildEngine contains a list of builders that can be used to build an ImageSpec. |
| [`flyte.extend.TaskTemplate`](../packages/flyte.extend/tasktemplate) |Task template is a template for a task that can be executed. |
| [`flyte.extras.ContainerTask`](../packages/flyte.extras/containertask) |This is an intermediate class that represents Flyte Tasks that run a container at execution time. |
| [`flyte.git.GitStatus`](../packages/flyte.git/gitstatus) |A class representing the status of a git repository. |
| [`flyte.io.DataFrame`](../packages/flyte.io/dataframe) |This is the user facing DataFrame class. |
| [`flyte.io.DataFrameDecoder`](../packages/flyte.io/dataframedecoder) |Helper class that provides a standard way to create an ABC using inheritance. |
| [`flyte.io.DataFrameEncoder`](../packages/flyte.io/dataframeencoder) |Helper class that provides a standard way to create an ABC using inheritance. |
| [`flyte.io.DataFrameTransformerEngine`](../packages/flyte.io/dataframetransformerengine) |Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. |
| [`flyte.io.Dir`](../packages/flyte.io/dir) |A generic directory class representing a directory with files of a specified format. |
| [`flyte.io.File`](../packages/flyte.io/file) |A generic file class representing a file with a specified format. |
| [`flyte.models.ActionID`](../packages/flyte.models/actionid) |A class representing the ID of an Action, nested within a Run. |
| [`flyte.models.Checkpoints`](../packages/flyte.models/checkpoints) |A class representing the checkpoints for a task. |
| [`flyte.models.CodeBundle`](../packages/flyte.models/codebundle) |A class representing a code bundle for a task. |
| [`flyte.models.GroupData`](../packages/flyte.models/groupdata) | |
| [`flyte.models.NativeInterface`](../packages/flyte.models/nativeinterface) |A class representing the native interface for a task. |
| [`flyte.models.PathRewrite`](../packages/flyte.models/pathrewrite) |Configuration for rewriting paths during input loading. |
| [`flyte.models.RawDataPath`](../packages/flyte.models/rawdatapath) |A class representing the raw data path for a task. |
| [`flyte.models.SerializationContext`](../packages/flyte.models/serializationcontext) |This object holds serialization time contextual information, that can be used when serializing the task and. |
| [`flyte.models.TaskContext`](../packages/flyte.models/taskcontext) |A context class to hold the current task executions context. |
| [`flyte.remote.Action`](../packages/flyte.remote/action) |A class representing an action. |
| [`flyte.remote.ActionDetails`](../packages/flyte.remote/actiondetails) |A class representing an action. |
| [`flyte.remote.ActionInputs`](../packages/flyte.remote/actioninputs) |A class representing the inputs of an action. |
| [`flyte.remote.ActionOutputs`](../packages/flyte.remote/actionoutputs) |A class representing the outputs of an action. |
| [`flyte.remote.App`](../packages/flyte.remote/app) |A mixin class that provides a method to convert an object to a JSON-serializable dictionary. |
| [`flyte.remote.Project`](../packages/flyte.remote/project) |A class representing a project in the Union API. |
| [`flyte.remote.Run`](../packages/flyte.remote/run) |A class representing a run of a task. |
| [`flyte.remote.RunDetails`](../packages/flyte.remote/rundetails) |A class representing a run of a task. |
| [`flyte.remote.Secret`](../packages/flyte.remote/secret) | |
| [`flyte.remote.Task`](../packages/flyte.remote/task) | |
| [`flyte.remote.TaskDetails`](../packages/flyte.remote/taskdetails) | |
| [`flyte.remote.Trigger`](../packages/flyte.remote/trigger) | |
| [`flyte.remote.User`](../packages/flyte.remote/user) | |
| [`flyte.report.Report`](../packages/flyte.report/report) | |
| [`flyte.storage.ABFS`](../packages/flyte.storage/abfs) |Any Azure Blob Storage specific configuration. |
| [`flyte.storage.GCS`](../packages/flyte.storage/gcs) |Any GCS specific configuration. |
| [`flyte.storage.S3`](../packages/flyte.storage/s3) |S3 specific configuration. |
| [`flyte.storage.Storage`](../packages/flyte.storage/storage) |Data storage configuration that applies across any provider. |
| [`flyte.syncify.Syncify`](../packages/flyte.syncify/syncify) |A decorator to convert asynchronous functions or methods into synchronous ones. |
| [`flyte.types.FlytePickle`](../packages/flyte.types/flytepickle) |This type is only used by flytekit internally. |
| [`flyte.types.TypeEngine`](../packages/flyte.types/typeengine) |Core Extensible TypeEngine of Flytekit. |
| [`flyte.types.TypeTransformer`](../packages/flyte.types/typetransformer) |Base transformer type that should be implemented for every python native type that can be handled by flytekit. |
| [`flyte.types.TypeTransformerFailedError`](../packages/flyte.types/typetransformerfailederror) |Inappropriate argument type. |
# Protocols
| Protocol | Description |
|-|-|
| [`flyte.CachePolicy`](../packages/flyte/cachepolicy) |Base class for protocol classes. |
| [`flyte.types.Renderable`](../packages/flyte.types/renderable) |Base class for protocol classes. |
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages ===
# Packages
| Package | Description |
|-|-|
| **Flyte SDK > Packages > flyte** | Flyte SDK for authoring compound AI applications, services and workflows. |
| **Flyte SDK > Packages > flyte.app** | |
| **Flyte SDK > Packages > flyte.config** | |
| **Flyte SDK > Packages > flyte.errors** | Exceptions raised by Union. |
| **Flyte SDK > Packages > flyte.extend** | |
| **Flyte SDK > Packages > flyte.extras** | |
| **Flyte SDK > Packages > flyte.git** | |
| **Flyte SDK > Packages > flyte.io** | IO data types. |
| **Flyte SDK > Packages > flyte.models** | |
| **Flyte SDK > Packages > flyte.remote** | Remote Entities that are accessible from the Union Server once deployed or created. |
| **Flyte SDK > Packages > flyte.report** | |
| **Flyte SDK > Packages > flyte.storage** | |
| **Flyte SDK > Packages > flyte.syncify** | Syncify module. |
| **Flyte SDK > Packages > flyte.types** | Flyte type system. |
## Subpages
- **Flyte SDK > Packages > flyte**
- **Flyte SDK > Packages > flyte.app**
- **Flyte SDK > Packages > flyte.config**
- **Flyte SDK > Packages > flyte.errors**
- **Flyte SDK > Packages > flyte.extend**
- **Flyte SDK > Packages > flyte.extras**
- **Flyte SDK > Packages > flyte.git**
- **Flyte SDK > Packages > flyte.io**
- **Flyte SDK > Packages > flyte.models**
- **Flyte SDK > Packages > flyte.remote**
- **Flyte SDK > Packages > flyte.report**
- **Flyte SDK > Packages > flyte.storage**
- **Flyte SDK > Packages > flyte.syncify**
- **Flyte SDK > Packages > flyte.types**
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte ===
# flyte
Flyte SDK for authoring compound AI applications, services and workflows.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte > `Cache`** | Cache configuration for a task. |
| **Flyte SDK > Packages > flyte > `Cron`** | This class defines a Cron automation that can be associated with a Trigger in Flyte. |
| **Flyte SDK > Packages > flyte > `Device`** | Represents a device type, its quantity and partition if applicable. |
| **Flyte SDK > Packages > flyte > `Environment`** | |
| **Flyte SDK > Packages > flyte > `FixedRate`** | This class defines a FixedRate automation that can be associated with a Trigger in Flyte. |
| **Flyte SDK > Packages > flyte > `Image`** | This is a representation of Container Images, which can be used to create layered images programmatically. |
| **Flyte SDK > Packages > flyte > `PodTemplate`** | Custom PodTemplate specification for a Task. |
| **Flyte SDK > Packages > flyte > `Resources`** | Resources such as CPU, Memory, and GPU that can be allocated to a task. |
| **Flyte SDK > Packages > flyte > `RetryStrategy`** | Retry strategy for the task or task environment. |
| **Flyte SDK > Packages > flyte > `ReusePolicy`** | ReusePolicy can be used to configure a task to reuse the environment. |
| **Flyte SDK > Packages > flyte > `Secret`** | Secrets are used to inject sensitive information into tasks or image build context. |
| **Flyte SDK > Packages > flyte > `TaskEnvironment`** | Environment class to define a new environment for a set of tasks. |
| **Flyte SDK > Packages > flyte > `Timeout`** | Timeout class to define a timeout for a task. |
| **Flyte SDK > Packages > flyte > `Trigger`** | This class defines specification of a Trigger, that can be associated with any Flyte V2 task. |
### Protocols
| Protocol | Description |
|-|-|
| **Flyte SDK > Packages > flyte > `CachePolicy`** | Base class for protocol classes. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte > `AMD_GPU()`** | Create an AMD GPU device instance. |
| **Flyte SDK > Packages > flyte > `GPU()`** | Create a GPU device instance. |
| **Flyte SDK > Packages > flyte > `HABANA_GAUDI()`** | Create a Habana Gaudi device instance. |
| **Flyte SDK > Packages > flyte > `Neuron()`** | Create a Neuron device instance. |
| **Flyte SDK > Packages > flyte > `TPU()`** | Create a TPU device instance. |
| **Flyte SDK > Packages > flyte > `build()`** | Build an image. |
| **Flyte SDK > Packages > flyte > `build_images()`** | Build the images for the given environments. |
| **Flyte SDK > Packages > flyte > `ctx()`** | Returns flyte. |
| **Flyte SDK > Packages > flyte > `current_domain()`** | Returns the current domain from Runtime environment (on the cluster) or from the initialized configuration. |
| **Flyte SDK > Packages > flyte > `custom_context()`** | Synchronous context manager to set input context for tasks spawned within this block. |
| **Flyte SDK > Packages > flyte > `deploy()`** | Deploy the given environment or list of environments. |
| **Flyte SDK > Packages > flyte > `get_custom_context()`** | Get the current input context. |
| **Flyte SDK > Packages > flyte > `group()`** | Create a new group with the given name. |
| **Flyte SDK > Packages > flyte > `init()`** | Initialize the Flyte system with the given configuration. |
| **Flyte SDK > Packages > flyte > `init_from_config()`** | Initialize the Flyte system using a configuration file or Config object. |
| **Flyte SDK > Packages > flyte > `init_in_cluster()`** | |
| **Flyte SDK > Packages > flyte > `map()`** | Map a function over the provided arguments with concurrent execution. |
| **Flyte SDK > Packages > flyte > `run()`** | Run a task with the given parameters. |
| **Flyte SDK > Packages > flyte > `serve()`** | Serve a Flyte app using an AppEnvironment. |
| **Flyte SDK > Packages > flyte > `trace()`** | A decorator that traces function execution with timing information. |
| **Flyte SDK > Packages > flyte > `version()`** | Returns the version of the Flyte SDK. |
| **Flyte SDK > Packages > flyte > `with_runcontext()`** | Launch a new run with the given parameters as the context. |
| **Flyte SDK > Packages > flyte > `with_servecontext()`** | Create a serve context with custom configuration. |
### Variables
| Property | Type | Description |
|-|-|-|
| `TimeoutType` | `UnionType` | |
| `TriggerTime` | `_trigger_time` | |
| `__version__` | `str` | |
| `logger` | `Logger` | |
## Methods
#### AMD_GPU()
```python
def AMD_GPU(
    device: typing.Literal['MI100', 'MI210', 'MI250', 'MI250X', 'MI300A', 'MI300X', 'MI325X', 'MI350X', 'MI355X'],
) -> flyte._resources.Device
```
Create an AMD GPU device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['MI100', 'MI210', 'MI250', 'MI250X', 'MI300A', 'MI300X', 'MI325X', 'MI350X', 'MI355X']` | Device type (e.g., "MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X", "MI325X", "MI350X", "MI355X"). |
#### GPU()
```python
def GPU(
    device: typing.Literal['A10', 'A10G', 'A100', 'A100 80G', 'B200', 'H100', 'H200', 'L4', 'L40s', 'T4', 'V100', 'RTX PRO 6000'],
    quantity: typing.Literal[1, 2, 3, 4, 5, 6, 7, 8],
    partition: typing.Union[typing.Literal['1g.5gb', '2g.10gb', '3g.20gb', '4g.20gb', '7g.40gb'], typing.Literal['1g.10gb', '2g.20gb', '3g.40gb', '4g.40gb', '7g.80gb'], typing.Literal['1g.18gb', '1g.35gb', '2g.35gb', '3g.71gb', '4g.71gb', '7g.141gb'], NoneType],
) -> flyte._resources.Device
```
Create a GPU device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['A10', 'A10G', 'A100', 'A100 80G', 'B200', 'H100', 'H200', 'L4', 'L40s', 'T4', 'V100', 'RTX PRO 6000']` | The type of GPU (e.g., "T4", "A100"). |
| `quantity` | `typing.Literal[1, 2, 3, 4, 5, 6, 7, 8]` | The number of GPUs of this type. |
| `partition` | `typing.Union[typing.Literal['1g.5gb', '2g.10gb', '3g.20gb', '4g.20gb', '7g.40gb'], typing.Literal['1g.10gb', '2g.20gb', '3g.40gb', '4g.40gb', '7g.80gb'], typing.Literal['1g.18gb', '1g.35gb', '2g.35gb', '3g.71gb', '4g.71gb', '7g.141gb'], NoneType]` | The MIG partition of the GPU (e.g., "1g.5gb", "2g.10gb"), or `None` for no partition. |
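The GPU `partition` values are NVIDIA Multi-Instance GPU (MIG) profile names of the form `<compute slices>g.<memory>gb`. They fall into three sets whose full-GPU `7g` profile has 40, 80, or 141 GB. As a rough illustration, the hypothetical helper below (not part of the Flyte SDK) checks a string against the sets copied from the type hint and splits it into its parts before you pass it as `partition`:

```python
import re
from typing import NamedTuple

# Partition sets copied from the `partition` type hint of `flyte.GPU` above.
# Each set ends with its full-GPU "7g" profile.
PARTITION_GROUPS = [
    ("1g.5gb", "2g.10gb", "3g.20gb", "4g.20gb", "7g.40gb"),
    ("1g.10gb", "2g.20gb", "3g.40gb", "4g.40gb", "7g.80gb"),
    ("1g.18gb", "1g.35gb", "2g.35gb", "3g.71gb", "4g.71gb", "7g.141gb"),
]

_PROFILE = re.compile(r"^(\d+)g\.(\d+)gb$")


class MigProfile(NamedTuple):
    compute_slices: int       # the "<n>g" part
    memory_gb: int            # the "<m>gb" part
    full_gpu_memory_gb: int   # memory of the 7g profile in the same set


def parse_partition(partition: str) -> MigProfile:
    """Hypothetical helper: validate and decompose a MIG partition string."""
    match = _PROFILE.match(partition)
    if match is None:
        raise ValueError(f"not a MIG profile string: {partition!r}")
    for group in PARTITION_GROUPS:
        if partition in group:
            full = _PROFILE.match(group[-1])
            return MigProfile(int(match.group(1)), int(match.group(2)), int(full.group(2)))
    raise ValueError(f"{partition!r} is not one of the partitions listed for flyte.GPU")
```

This reference does not state which GPU models accept which set, so the helper does not check that.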
#### HABANA_GAUDI()
```python
def HABANA_GAUDI(
    device: typing.Literal['Gaudi1'],
) -> flyte._resources.Device
```
Create a Habana Gaudi device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['Gaudi1']` | Device type (e.g., "Gaudi1"). |
#### Neuron()
```python
def Neuron(
    device: typing.Literal['Inf1', 'Inf2', 'Trn1', 'Trn1n', 'Trn2', 'Trn2u'],
) -> flyte._resources.Device
```
Create a Neuron device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['Inf1', 'Inf2', 'Trn1', 'Trn1n', 'Trn2', 'Trn2u']` | Device type (e.g., "Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u"). |
#### TPU()
```python
def TPU(
    device: typing.Literal['V5P', 'V6E'],
    partition: typing.Union[typing.Literal['2x2x1', '2x2x2', '2x4x4', '4x4x4', '4x4x8', '4x8x8', '8x8x8', '8x8x16', '8x16x16', '16x16x16', '16x16x24'], typing.Literal['1x1', '2x2', '2x4', '4x4', '4x8', '8x8', '8x16', '16x16'], NoneType],
)
```
Create a TPU device instance.
| Parameter | Type | Description |
|-|-|-|
| `device` | `typing.Literal['V5P', 'V6E']` | Device type (e.g., "V5P", "V6E"). |
| `partition` | `typing.Union[typing.Literal['2x2x1', '2x2x2', '2x4x4', '4x4x4', '4x4x8', '4x8x8', '8x8x8', '8x8x16', '8x16x16', '16x16x16', '16x16x24'], typing.Literal['1x1', '2x2', '2x4', '4x4', '4x8', '8x8', '8x16', '16x16'], NoneType]` | Partition of the TPU (e.g., "1x1", "2x2", ...). |
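Each TPU `partition` is a topology string that gives the number of chips along each axis, so the product of the dimensions is the chip count. This follows Google Cloud's TPU topology convention; the Flyte reference itself does not spell it out. A minimal sketch with a hypothetical helper that is not part of the SDK:

```python
import math

# Partitions copied from the `partition` type hint of `flyte.TPU` above.
TPU_3D = {"2x2x1", "2x2x2", "2x4x4", "4x4x4", "4x4x8", "4x8x8", "8x8x8",
          "8x8x16", "8x16x16", "16x16x16", "16x16x24"}
TPU_2D = {"1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"}


def chips_in_topology(partition: str) -> int:
    """Hypothetical helper: number of chips implied by a topology such as '2x2x1'."""
    if partition not in TPU_3D | TPU_2D:
        raise ValueError(f"{partition!r} is not a partition listed for flyte.TPU")
    # Multiply the per-axis chip counts together.
    return math.prod(int(dim) for dim in partition.split("x"))
```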
#### build()
> [!NOTE] This method can be called both synchronously and asynchronously.
> The default invocation is synchronous and blocks.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await build.aio()`.
```python
def build(
    image: Image,
) -> str
```
Build an image. The existing async context will be used.
Example:
```python
import asyncio

import flyte

image = flyte.Image("example_image")

if __name__ == "__main__":
    # Build asynchronously by awaiting the `.aio()` form inside an event loop.
    asyncio.run(flyte.build.aio(image))
```
| Parameter | Type | Description |
|-|-|-|
| `image` | `Image` | The image to build. The URI of the built image is returned. |
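Several functions on this page, including `build()`, `deploy()`, `init()`, and `run()`, share the sync-or-`.aio()` calling convention described in the note above. The sketch below models that convention with plain `asyncio`. The `syncify` decorator and `fake_build` function are hypothetical and do not reflect how the Flyte SDK implements it:

```python
import asyncio
import functools
from typing import Any, Awaitable, Callable


def syncify(coro_fn: Callable[..., Awaitable[Any]]) -> Callable[..., Any]:
    """Expose an async function as a blocking call, keeping the original as `.aio`."""

    @functools.wraps(coro_fn)
    def blocking_call(*args: Any, **kwargs: Any) -> Any:
        # Runs the coroutine on a new event loop, so this form must not be
        # called from code that is already running inside an event loop.
        return asyncio.run(coro_fn(*args, **kwargs))

    blocking_call.aio = coro_fn  # async form: `await fn.aio(...)`
    return blocking_call


@syncify
async def fake_build(image_name: str) -> str:
    # Hypothetical stand-in for an image build that returns an image URI.
    await asyncio.sleep(0)
    return f"registry.example.com/{image_name}:latest"


# The default, synchronous form blocks until the result is ready.
uri = fake_build("example_image")


async def main() -> str:
    # Inside a coroutine, use the `.aio` form instead.
    return await fake_build.aio("example_image")
```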
#### build_images()
> [!NOTE] This method can be called both synchronously and asynchronously.
> The default invocation is synchronous and blocks.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await build_images.aio()`.
```python
def build_images(
    envs: Environment,
) -> ImageCache
```
Build the images for the given environments.
| Parameter | Type | Description |
|-|-|-|
| `envs` | `Environment` | The environment(s) to build images for. Returns an `ImageCache` containing the built images. |
#### ctx()
```python
def ctx()
```
Returns `flyte.models.TaskContext` if within a task context, else `None`.
Note: Only use this in task code, not at module level.
#### current_domain()
```python
def current_domain()
```
Returns the current domain from the runtime environment (on the cluster) or from the initialized configuration.
This is safe to use during `deploy`, `run`, and within `task` code.
NOTE: This will not work if you deploy a task to one domain and then run it in another domain.
Raises `InitializationError` if the configuration is not initialized or the domain is not set.
Returns the current domain.
#### custom_context()
```python
def custom_context(
    context: str,
)
```
Synchronous context manager to set input context for tasks spawned within this block.
Example:
```python
import flyte

env = flyte.TaskEnvironment(name="...")

@env.task
def t1():
    ctx = flyte.get_custom_context()
    print(ctx)

@env.task
def main():
    # context can be passed via a context manager
    with flyte.custom_context(project="my-project"):
        t1()  # will have {'project': 'my-project'} as context
```
| Parameter | Type | Description |
|-|-|-|
| `context` | `str` | Key-value pairs to set as input context, passed as keyword arguments (e.g., `project="my-project"`). |
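The propagation and override behavior described for `custom_context()` and `get_custom_context()` can be modeled with Python's `contextvars`. This is an illustrative sketch using hypothetical stand-in functions, not Flyte's implementation. In particular, it only shares context within one process and does not carry it across separate task executions:

```python
import contextlib
import contextvars
from typing import Dict, Iterator

# Holds the active key-value context. The default empty dict is never mutated.
_custom_ctx = contextvars.ContextVar("custom_ctx", default={})


@contextlib.contextmanager
def custom_context(**context: str) -> Iterator[None]:
    # Layer the new pairs over the current context, so inner blocks override outer ones.
    token = _custom_ctx.set({**_custom_ctx.get(), **context})
    try:
        yield
    finally:
        # Restore whatever context was active before the block.
        _custom_ctx.reset(token)


def get_custom_context() -> Dict[str, str]:
    # Return a copy so callers cannot modify the active context.
    return dict(_custom_ctx.get())


with custom_context(project="my-project"):
    with custom_context(project="override", entity="team-a"):
        inner = get_custom_context()
    outer = get_custom_context()
after = get_custom_context()
```

Here `inner` is `{'project': 'override', 'entity': 'team-a'}`, `outer` is `{'project': 'my-project'}`, and `after` is empty once both blocks have exited.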
#### deploy()
> [!NOTE] This method can be called both synchronously and asynchronously.
> The default invocation is synchronous and blocks.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await deploy.aio()`.
```python
def deploy(
    envs: Environment,
    dryrun: bool,
    version: str | None,
    interactive_mode: bool | None,
    copy_style: CopyFiles,
) -> List[Deployment]
```
Deploy the given environment or list of environments.
| Parameter | Type | Description |
|-|-|-|
| `envs` | `Environment` | Environment or list of environments to deploy. |
| `dryrun` | `bool` | Dry-run mode. If True, the deployment is not applied to the control plane. |
| `version` | `str \| None` | Version of the deployment. If None, the version is computed from the code bundle. |
| `interactive_mode` | `bool \| None` | Optional; can be forced to True or False. If not provided, it is set based on the current environment: for example, Jupyter notebooks are considered interactive, while scripts are not. This determines how the code bundle is created. |
| `copy_style` | `CopyFiles` | Copy style to use when running the task. Returns `Deployment` objects containing the deployed environments and tasks. |
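When `version` is None, the deployment version is computed from the code bundle. The exact scheme is not documented here. One plausible approach, sketched below with a hypothetical `bundle_version` helper, is a content hash over the bundled files, so that unchanged code always maps to the same version:

```python
import hashlib
from pathlib import Path


def bundle_version(root: Path, length: int = 16) -> str:
    """Hypothetical scheme: derive a deterministic version from the files under `root`."""
    digest = hashlib.sha256()
    # Sort paths so the result does not depend on filesystem iteration order.
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        # Hash the relative path as well as the contents, so renames change the version.
        digest.update(path.relative_to(root).as_posix().encode())
        digest.update(b"\0")
        digest.update(path.read_bytes())
        digest.update(b"\0")
    return digest.hexdigest()[:length]
```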
#### get_custom_context()
```python
def get_custom_context()
```
Get the current input context. This can be used within a task to retrieve
context metadata that was passed to the action.
Context will automatically propagate to sub-actions.
Example:
```python
import flyte

env = flyte.TaskEnvironment(name="...")

@env.task
def t1():
    # context can be retrieved with `get_custom_context`
    ctx = flyte.get_custom_context()
    print(ctx)  # {'project': '...', 'entity': '...'}
```
Returns a dictionary of context key-value pairs.
#### group()
```python
def group(
    name: str,
)
```
Create a new group with the given name. The method is intended to be used as a context manager.
Example:
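```python
import flyte

env = flyte.TaskEnvironment(name="...")

@env.task
async def t1(x: int, y: int) -> int:
    return x + y

@env.task
async def my_task(x: int, y: int):
    with flyte.group("my_group"):
        await t1(x, y)  # tasks in this block will be grouped under "my_group"
```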
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the group |
#### init()
> [!NOTE] This method can be called both synchronously and asynchronously.
> The default invocation is synchronous and blocks.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await init.aio()`.
```python
def init(
    org: str | None,
    project: str | None,
    domain: str | None,
    root_dir: Path | None,
    log_level: int | None,
    log_format: LogFormat | None,
    endpoint: str | None,
    headless: bool,
    insecure: bool,
    insecure_skip_verify: bool,
    ca_cert_file_path: str | None,
    auth_type: AuthType,
    command: List[str] | None,
    proxy_command: List[str] | None,
    api_key: str | None,
    client_id: str | None,
    client_credentials_secret: str | None,
    auth_client_config: ClientConfig | None,
    rpc_retries: int,
    http_proxy_url: str | None,
    storage: Storage | None,
    batch_size: int,
    image_builder: ImageBuildEngine.ImageBuilderType,
    images: typing.Dict[str, str] | None,
    source_config_path: Optional[Path],
    sync_local_sys_paths: bool,
    load_plugin_type_transformers: bool,
)
```
Initialize the Flyte system with the given configuration. Call this method before any other Flyte
remote API methods. This implementation is thread-safe.
| Parameter | Type | Description |
|-|-|-|
| `org` | `str \| None` | Optional organization override for the client. This should normally be set by auth instead. |
| `project` | `str \| None` | Optional project name (not used in this implementation). |
| `domain` | `str \| None` | Optional domain name (not used in this implementation). |
| `root_dir` | `Path \| None` | Optional root directory used to resolve file paths, locate files such as the config, and determine all the code that needs to be copied to the remote location. Defaults to the editable install directory if the current working directory is inside a Python editable install; otherwise, the current working directory. |
| `log_level` | `int \| None` | Optional logging level for the logger. The default is set by the default initialization policies. |
| `log_format` | `LogFormat \| None` | Optional logging format for the logger. Defaults to "console". |
| `endpoint` | `str \| None` | Optional API endpoint URL. |
| `headless` | `bool` | Optional. Whether to run in headless mode. |
| `insecure` | `bool` | Insecure flag for the client. |
| `insecure_skip_verify` | `bool` | Whether to skip SSL certificate verification. |
| `ca_cert_file_path` | `str \| None` | Optional path to a root certificate to load and use to verify the admin endpoint. |
| `auth_type` | `AuthType` | The authentication type to use (Pkce, ClientSecret, ExternalCommand, DeviceFlow). |
| `command` | `List[str] \| None` | Command executed in an external process to return a token. |
| `proxy_command` | `List[str] \| None` | Command executed in an external process to return a token for proxy authorization. |
| `api_key` | `str \| None` | Optional API key for authentication. |
| `client_id` | `str \| None` | The public identifier for the app that handles authorization for a Flyte deployment. More details: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. |
| `client_credentials_secret` | `str \| None` | Used for service auth, which is automatically called during pyflyte. This allows the Flyte engine to read the password directly from the environment variable. This is less secure; only use it if mounting the secret as a file is impossible. |
| `auth_client_config` | `ClientConfig \| None` | Optional client configuration for authentication. |
| `rpc_retries` | `int` | Optional number of times to retry platform calls. |
| `http_proxy_url` | `str \| None` | Optional HTTP proxy to use for OAuth requests. |
| `storage` | `Storage \| None` | Optional blob store (S3, GCS, Azure) configuration, if needed for access (e.g., when using MinIO). |
| `batch_size` | `int` | Optional batch size for operations that use listings (default: 1000). A listing whose limit is larger than `batch_size` is split into multiple requests. |
| `image_builder` | `ImageBuildEngine.ImageBuilderType` | Optional image builder configuration. If not provided, the default image builder is used. |
| `images` | `typing.Dict[str, str] \| None` | Optional dict of images that can be used by referencing the image name. |
| `source_config_path` | `Optional[Path]` | Optional path to the source configuration file (used only for documentation). |
| `sync_local_sys_paths` | `bool` | Whether to include and synchronize local sys.path entries under the root directory into the remote container (default: True). |
| `load_plugin_type_transformers` | `bool` | If enabled (default: True), load the type transformer plugins registered under the "flyte.plugins.types" entry point group. |
#### init_from_config()
> [!NOTE] This method can be called both synchronously and asynchronously.
> The default invocation is synchronous and blocks.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await init_from_config.aio()`.
```python
def init_from_config(
    path_or_config: str | Path | Config | None,
    root_dir: Path | None,
    log_level: int | None,
    log_format: LogFormat,
    storage: Storage | None,
    images: tuple[str, ...] | None,
    sync_local_sys_paths: bool,
)
```
Initialize the Flyte system using a configuration file or Config object. Call this method before any
other Flyte remote API methods. This implementation is thread-safe.
| Parameter | Type | Description |
|-|-|-|
| `path_or_config` | `str \| Path \| Config \| None` | Path to the configuration file or Config object |
| `root_dir` | `Path \| None` | Optional root directory used to resolve file paths and locate files such as the config. For example, when the copy style is "all", the root directory of the current project must be determined correctly. If not provided, it defaults to the editable install directory or, if that is not available, the current working directory. |
| `log_level` | `int \| None` | Optional logging level for the framework logger. The default is set by the default initialization policies. |
| `log_format` | `LogFormat` | Optional logging format for the logger. Defaults to "console". |
| `storage` | `Storage \| None` | Optional blob store (S3, GCS, Azure) configuration, if needed for access (e.g., when using MinIO). |
| `images` | `tuple[str, ...] \| None` | Image strings in the format "imagename=imageuri" or just "imageuri". |
| `sync_local_sys_paths` | `bool` | Whether to include and synchronize local sys.path entries under the root directory into the remote container (default: True). |
#### init_in_cluster()
> [!NOTE] This method can be called both synchronously and asynchronously.
> The default invocation is synchronous and blocks.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await init_in_cluster.aio()`.
```python
def init_in_cluster(
    org: str | None,
    project: str | None,
    domain: str | None,
    api_key: str | None,
    endpoint: str | None,
    insecure: bool,
) -> dict[str, typing.Any]
```
| Parameter | Type | Description |
|-|-|-|
| `org` | `str \| None` | |
| `project` | `str \| None` | |
| `domain` | `str \| None` | |
| `api_key` | `str \| None` | |
| `endpoint` | `str \| None` | |
| `insecure` | `bool` | |
#### map()
> [!NOTE] This method can be called both synchronously and asynchronously.
> The default invocation is synchronous and blocks.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await flyte.map.aio()`.
```python
def map(
    func: typing.Union[flyte._task.AsyncFunctionTaskTemplate[~P, ~R, ~F], functools.partial[~R]],
    *args,
    group_name: str | None,
    concurrency: int,
    return_exceptions: bool,
) -> typing.Iterator[typing.Union[~R, Exception]]
```
Map a function over the provided arguments with concurrent execution.
| Parameter | Type | Description |
|-|-|-|
| `func` | `typing.Union[flyte._task.AsyncFunctionTaskTemplate[~P, ~R, ~F], functools.partial[~R]]` | The async function to map. |
| `args` | `*args` | Positional arguments to pass to the function (iterables that will be zipped). |
| `group_name` | `str \| None` | The name of the group for the mapped tasks. |
| `concurrency` | `int` | The maximum number of concurrent tasks to run. If 0, run all tasks concurrently. |
| `return_exceptions` | `bool` | If True, yield exceptions instead of raising them. Returns an iterator (an async iterator when called via `.aio()`) that yields results in order. |
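The documented behavior of `flyte.map` can be modeled with plain `asyncio`: positional iterables are zipped, at most `concurrency` calls run at once (0 means no limit), results are yielded in input order, and `return_exceptions` controls error handling. The `bounded_map` function below is a hypothetical sketch of those semantics, not the SDK's implementation. It runs coroutines in one process rather than launching tasks:

```python
import asyncio
from typing import Any, AsyncIterator, Awaitable, Callable, Iterable


async def bounded_map(
    func: Callable[..., Awaitable[Any]],
    *iterables: Iterable[Any],
    concurrency: int = 0,
    return_exceptions: bool = False,
) -> AsyncIterator[Any]:
    # A concurrency of 0 means "no limit", matching the parameter description above.
    limiter = asyncio.Semaphore(concurrency) if concurrency > 0 else None

    async def call(args: tuple) -> Any:
        if limiter is None:
            return await func(*args)
        async with limiter:
            return await func(*args)

    # The iterables are zipped: the i-th call receives the i-th item of each one.
    tasks = [asyncio.create_task(call(args)) for args in zip(*iterables)]
    try:
        # Await in submission order so results come back in input order.
        for task in tasks:
            try:
                yield await task
            except Exception as exc:
                if not return_exceptions:
                    raise
                yield exc
    finally:
        # Cancel anything still pending if the caller stops early or an error is raised.
        for task in tasks:
            task.cancel()
```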
#### run()
> [!NOTE] This method can be called both synchronously and asynchronously.
> The default invocation is synchronous and blocks.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await run.aio()`.
```python
def run(
    task: TaskTemplate[P, R, F],
    *args,
    **kwargs,
) -> Run
```
Run a task with the given parameters.
| Parameter | Type | Description |
|-|-|-|
| `task` | `TaskTemplate[P, R, F]` | The task to run. |
| `args` | `*args` | Positional arguments to pass to the task. |
| `kwargs` | `**kwargs` | Keyword arguments to pass to the task. Returns a `Run`, or the result of the task. |
#### serve()
> [!NOTE] This method can be called both synchronously and asynchronously.
> The default invocation is synchronous and blocks.
> To call it asynchronously, use `.aio()` on the method itself, e.g.:
> `result = await serve.aio()`.
```python
def serve(
    app_env: 'AppEnvironment',
) -> 'App'
```
Serve a Flyte app using an AppEnvironment.
This is the simple, direct way to serve an app. For more control over
deployment settings (env vars, cluster pool, etc.), use `with_servecontext()`.
Example:
```python
import flyte
from flyte.app.extras import FastAPIAppEnvironment
env = FastAPIAppEnvironment(name="my-app", ...)
# Simple serve
app = flyte.serve(env)
print(f"App URL: {app.url}")
```
| Parameter | Type | Description |
|-|-|-|
| `app_env` | `'AppEnvironment'` | The app environment to serve |
#### trace()
```python
def trace(
    func: typing.Callable[..., ~T],
) -> typing.Callable[..., ~T]
```
A decorator that traces function execution with timing information.
Works with regular functions, async functions, and async generators/iterators.
| Parameter | Type | Description |
|-|-|-|
| `func` | `typing.Callable[..., ~T]` | The function to trace. |
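`trace()` is described as a decorator that records timing for regular functions, async functions, and async generators. The sketch below shows one way a single decorator can handle all three cases using `inspect`. The `timed` decorator and the `TIMINGS` list are hypothetical, and this does not reproduce what Flyte does with its timing data:

```python
import asyncio
import functools
import inspect
import time
from typing import Any, Callable, List, Tuple

# Hypothetical sink for (qualified name, duration in seconds) records.
TIMINGS: List[Tuple[str, float]] = []


def timed(func: Callable[..., Any]) -> Callable[..., Any]:
    """Record wall-clock duration for sync functions, coroutines, and async generators."""
    name = func.__qualname__
    if inspect.isasyncgenfunction(func):
        @functools.wraps(func)
        async def agen_wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                # Re-yield every item; the duration covers the whole iteration.
                async for item in func(*args, **kwargs):
                    yield item
            finally:
                TIMINGS.append((name, time.perf_counter() - start))
        return agen_wrapper
    if inspect.iscoroutinefunction(func):
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return await func(*args, **kwargs)
            finally:
                TIMINGS.append((name, time.perf_counter() - start))
        return async_wrapper

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            TIMINGS.append((name, time.perf_counter() - start))
    return sync_wrapper


@timed
def square(x: int) -> int:
    return x * x


@timed
async def fetch(x: int) -> int:
    await asyncio.sleep(0)
    return x


@timed
async def count(n: int):
    for i in range(n):
        yield i
```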
#### version()
```python
def version()
```
Returns the version of the Flyte SDK.
#### with_runcontext()
```python
def with_runcontext(
    mode: Mode | None,
    name: Optional[str],
    service_account: Optional[str],
    version: Optional[str],
    copy_style: CopyFiles,
    dry_run: bool,
    copy_bundle_to: pathlib.Path | None,
    interactive_mode: bool | None,
    raw_data_path: str | None,
    run_base_dir: str | None,
    overwrite_cache: bool,
    project: str | None,
    domain: str | None,
    env_vars: Dict[str, str] | None,
    labels: Dict[str, str] | None,
    annotations: Dict[str, str] | None,
    interruptible: bool | None,
    log_level: int | None,
    log_format: LogFormat,
    disable_run_cache: bool,
    queue: Optional[str],
    custom_context: Dict[str, str] | None,
    cache_lookup_scope: CacheLookupScope,
) -> _Runner
```
Launch a new run with the given parameters as the context.
Example:
```python
import flyte

env = flyte.TaskEnvironment("example")

@env.task
async def example_task(x: int, y: str) -> str:
    return f"{x} {y}"

if __name__ == "__main__":
    flyte.with_runcontext(name="example_run_id").run(example_task, 1, y="hello")
```
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Mode \| None` | Optional. The mode to use for the run. If not provided, it is computed from `flyte.init`. |
| `name` | `Optional[str]` | Optional. The name to use for the run. |
| `service_account` | `Optional[str]` | Optional. The service account to use for the run context. |
| `version` | `Optional[str]` | Optional. The version to use for the run. If not provided, it is computed from the code bundle. |
| `copy_style` | `CopyFiles` | Optional. The copy style to use for the run context. |
| `dry_run` | `bool` | Optional. If True, the bundle is created but the run is not executed. |
| `copy_bundle_to` | `pathlib.Path \| None` | When `dry_run` is True, the bundle is copied to this location, if specified. |
| `interactive_mode` | `bool \| None` | Optional; can be forced to True or False. If not provided, it is set based on the current environment: for example, Jupyter notebooks are considered interactive, while scripts are not. This determines how the code bundle is created. |
| `raw_data_path` | `str \| None` | Optional path for storing the run's raw data, for both local and remote runs. Use it to place raw data in a specific location. |
| `run_base_dir` | `str \| None` | Optional. The base directory for the run, used to store the run metadata that is passed between tasks. |
| `overwrite_cache` | `bool` | Optional. If True, the cache is overwritten for the run. |
| `project` | `str \| None` | Optional. The project to use for the run. |
| `domain` | `str \| None` | Optional. The domain to use for the run. |
| `env_vars` | `Dict[str, str] \| None` | Optional. Environment variables to set for the run. |
| `labels` | `Dict[str, str] \| None` | Optional. Labels to set for the run. |
| `annotations` | `Dict[str, str] \| None` | Optional. Annotations to set for the run. |
| `interruptible` | `bool \| None` | Optional. If True, the run can be scheduled on interruptible instances; if False, all tasks in the run are scheduled only on non-interruptible instances. If not specified, the original setting on each task is retained. |
| `log_level` | `int \| None` | Optional. Log level to set for the run. If not provided, the default log level set using `flyte.init()` is used. |
| `log_format` | `LogFormat` | Optional. Log format to set for the run. If not provided, the default log format is used. |
| `disable_run_cache` | `bool` | Optional. If True, the run cache is disabled. This is useful for testing. |
| `queue` | `Optional[str]` | Optional. The queue to use for the run, which specifies the cluster the run uses. |
| `custom_context` | `Dict[str, str] \| None` | Optional global input context to pass to the task. It is available via `get_custom_context()` within the task and automatically propagates to sub-tasks. These values act as base defaults that context managers in the code can override. |
| `cache_lookup_scope` | `CacheLookupScope` | Optional. The scope to use for cache lookups. If not specified, the default scope is used (global, unless overridden at the system level). Returns a runner. |
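The `interruptible` parameter is tri-state. The hypothetical helper below (not part of the SDK) models the rule described above: `None` keeps each task's own setting, while `True` or `False` is assumed to apply to every task in the run:

```python
from typing import Dict, Optional


def resolve_interruptible(
    run_setting: Optional[bool], task_settings: Dict[str, bool]
) -> Dict[str, bool]:
    """Hypothetical model: effective per-task interruptible flag under a run-level value."""
    if run_setting is None:
        # Not specified: every task keeps its original setting.
        return dict(task_settings)
    # Specified: the run-level value applies to all tasks in the run.
    return {task: run_setting for task in task_settings}
```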
#### with_servecontext()
```python
def with_servecontext(
    version: Optional[str],
    copy_style: CopyFiles,
    dry_run: bool,
    project: str | None,
    domain: str | None,
    env_vars: dict[str, str] | None,
    input_values: dict[str, dict[str, str | flyte.io.File | flyte.io.Dir]] | None,
    cluster_pool: str | None,
    log_level: int | None,
    log_format: LogFormat,
) -> _Serve
```
Create a serve context with custom configuration.
This function allows you to customize how an app is served, including
overriding environment variables, cluster pool, logging, and other deployment settings.
Example:
```python
import logging
import flyte
from flyte.app.extras import FastAPIAppEnvironment
env = FastAPIAppEnvironment(name="my-app", ...)
# Serve with custom env vars, logging, and cluster pool
app = flyte.with_servecontext(
    env_vars={"DATABASE_URL": "postgresql://..."},
    log_level=logging.DEBUG,
    log_format="json",
    cluster_pool="gpu-pool",
    project="my-project",
    domain="development",
).serve(env)
print(f"App URL: {app.url}")
```
| Parameter | Type | Description |
|-|-|-|
| `version` | `Optional[str]` | Optional version override for the app deployment |
| `copy_style` | `CopyFiles` | |
| `dry_run` | `bool` | |
| `project` | `str \| None` | Optional project override |
| `domain` | `str \| None` | Optional domain override |
| `env_vars` | `dict[str, str] \| None` | Optional environment variables to inject/override in the app container |
| `input_values` | `dict[str, dict[str, str \| flyte.io.File \| flyte.io.Dir]] \| None` | Optional input values to inject/override in the app container. Must be a dictionary that maps app environment names to a dictionary of input names to values. |
| `cluster_pool` | `str \| None` | Optional cluster pool to deploy the app to |
| `log_level` | `int \| None` | Optional log level (e.g., `logging.DEBUG`, `logging.INFO`). If not provided, the init configuration or the default is used. |
| `log_format` | `LogFormat` | |
## Subpages
- [Cache](Cache/)
- [CachePolicy](CachePolicy/)
- [Cron](Cron/)
- [Device](Device/)
- [Environment](Environment/)
- [FixedRate](FixedRate/)
- [Image](Image/)
- [PodTemplate](PodTemplate/)
- [Resources](Resources/)
- [RetryStrategy](RetryStrategy/)
- [ReusePolicy](ReusePolicy/)
- [Secret](Secret/)
- [TaskEnvironment](TaskEnvironment/)
- [Timeout](Timeout/)
- [Trigger](Trigger/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.app ===
# flyte.app
## Directory
### Classes
| Class | Description |
|-|-|
| `AppEndpoint` | Embed an upstream app's endpoint as an app input. |
| `AppEnvironment` | |
| `Domain` | Subdomain to use for the domain. |
| `Input` | Input for application. |
| `Link` | Custom links to add to the app. |
| `Port` | |
| `RunOutput` | Use a run's output for app inputs. |
| `Scaling` | |
### Methods
| Method | Description |
|-|-|
| `get_input()` | Get inputs for application or endpoint. |
## Methods
#### get_input()
```python
def get_input(
    name: str,
) -> str
```
Get inputs for application or endpoint.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the input to retrieve. |
## Subpages
- AppEndpoint
- [AppEnvironment](AppEnvironment/)
- [Domain](Domain/)
- [Input](Input/)
- [Link](Link/)
- [Port](Port/)
- RunOutput
- [Scaling](Scaling/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.app/appendpoint ===
# AppEndpoint
**Package:** `flyte.app`
Embed an upstream app's endpoint as an app input.
This enables the declaration of an app input dependency on the endpoint of
an upstream app, given by a specific app name. This gives the app access to
the upstream app's endpoint as a public or private URL.
```python
class AppEndpoint(
    data: Any,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises `pydantic_core.ValidationError` if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `data` | `Any` | |
## Methods
| Method | Description |
|-|-|
| `check_type()` | |
| `construct()` | |
| `copy()` | Returns a copy of the model. |
| `dict()` | |
| `from_orm()` | |
| `get()` | |
| `json()` | |
| `materialize()` | |
| `model_construct()` | Creates a new instance of the `Model` class with validated data. |
| `model_copy()` | Returns a copy of the model. |
| `model_dump()` | Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. |
| `model_dump_json()` | Generates a JSON representation of the model using Pydantic's `to_json` method. |
| `model_json_schema()` | Generates a JSON schema for a model class. |
| `model_parametrized_name()` | Compute the class name for parametrizations of generic classes. |
| `model_post_init()` | Override this method to perform additional initialization after `__init__` and `model_construct`. |
| `model_rebuild()` | Try to rebuild the pydantic-core schema for the model. |
| `model_validate()` | Validate a pydantic model instance. |
| `model_validate_json()` | Validate the given JSON data against the Pydantic model. |
| `model_validate_strings()` | Validate the given object with string data against the Pydantic model. |
| `parse_file()` | |
| `parse_obj()` | |
| `parse_raw()` | |
| `schema()` | |
| `schema_json()` | |
| `update_forward_refs()` | |
| `validate()` | |
### check_type()
```python
def check_type(
    data: typing.Any,
) -> typing.Any
```
| Parameter | Type | Description |
|-|-|-|
| `data` | `typing.Any` | |
### construct()
```python
def construct(
    _fields_set: set[str] | None,
    values: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | |
| `values` | `Any` | |
### copy()
```python
def copy(
    include: AbstractSetIntStr | MappingIntStrAny | None,
    exclude: AbstractSetIntStr | MappingIntStrAny | None,
    update: Dict[str, Any] | None,
    deep: bool,
) -> Self
```
Returns a copy of the model.
> [!WARNING] Deprecated
> This method is now deprecated; use `model_copy` instead.
If you need `include` or `exclude`, use:
```python
data = self.model_dump(include=include, exclude=exclude, round_trip=True)
data = {**data, **(update or {})}
copied = self.model_validate(data)
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to include in the copied model. |
| `exclude` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to exclude in the copied model. |
| `update` | `Dict[str, Any] \| None` | Optional dictionary of field-value pairs to override field values in the copied model. |
| `deep` | `bool` | If True, the values of fields that are Pydantic models will be deep-copied. |
### dict()
```python
def dict(
    include: IncEx | None,
    exclude: IncEx | None,
    by_alias: bool,
    exclude_unset: bool,
    exclude_defaults: bool,
    exclude_none: bool,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
### from_orm()
```python
def from_orm(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### get()
```python
def get()
```
### json()
```python
def json(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
encoder: Callable[[Any], Any] | None,
models_as_dict: bool,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
| `encoder` | `Callable[[Any], Any] \| None` | |
| `models_as_dict` | `bool` | |
| `dumps_kwargs` | `Any` | |
### materialize()
```python
def materialize()
```
### model_construct()
```python
def model_construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
Creates a new instance of the `Model` class from already-validated data.
Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
> [!NOTE]
> `model_construct()` generally respects the `model_config.extra` setting on the provided model.
> That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
> and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
> Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
> an error if extra values are passed, but they will be ignored.
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | A set of field names that were originally explicitly set during instantiation. If provided, this is directly used for the [`model_fields_set`][pydantic.BaseModel.model_fields_set] attribute. Otherwise, the field names from the `values` argument will be used. |
| `values` | `Any` | Trusted or pre-validated data dictionary. |
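Below is a minimal sketch. It assumes `AppEndpoint` inherits `model_construct()` unchanged from `pydantic.BaseModel`, as the docstrings on this page suggest. This page does not document the class's own fields, so the sketch uses a hypothetical stand-in model, `Endpoint`. It shows two things: values are stored without validation, and defaults are still filled in.

```python
from pydantic import BaseModel


class Endpoint(BaseModel):  # hypothetical stand-in, not a Flyte class
    host: str
    port: int = 8080


# No validation runs: the bad port value is stored exactly as given
unchecked = Endpoint.model_construct(host="localhost", port="not-a-port")

# Defaults are still filled in for fields that were not passed
partial = Endpoint.model_construct(host="localhost")
print(unchecked.port)            # not-a-port
print(partial.port)              # 8080
print(partial.model_fields_set)  # {'host'}
```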
### model_copy()
```python
def model_copy(
update: Mapping[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!NOTE]
> The underlying instance's [`__dict__`][object.__dict__] attribute is copied. This
> might have unexpected side effects if you store anything in it, on top of the model
> fields (e.g. the value of [cached properties][functools.cached_property]).
| Parameter | Type | Description |
|-|-|-|
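| `update` | `Mapping[str, Any] \| None` | Values to change or add in the new model. The data is not validated before the new model is created, so it must be trusted. |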
| `deep` | `bool` | Set to `True` to make a deep copy of the model. |
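The sketch below assumes `AppEndpoint.model_copy()` behaves exactly like `pydantic.BaseModel.model_copy()`. It uses a hypothetical `Settings` model to contrast three things: `update=`, a shallow copy (the default) and a copy made with `deep=True`.

```python
from pydantic import BaseModel


class Settings(BaseModel):  # hypothetical model
    tags: list[str]
    replicas: int = 1


original = Settings(tags=["a"])

# update= overrides values in the copy; the new values are NOT validated
scaled = original.model_copy(update={"replicas": 3})

# A shallow copy (the default) shares mutable field values with the original
shallow = original.model_copy()
deep = original.model_copy(deep=True)
shallow.tags.append("b")

print(scaled.replicas, original.replicas)  # 3 1
print(original.tags)  # ['a', 'b'] (mutated through the shallow copy)
print(deep.tags)      # ['a'] (independent copy)
```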
### model_dump()
```python
def model_dump(
mode: Literal['json', 'python'] | str,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> dict[str, Any]
```
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Literal['json', 'python'] \| str` | The mode in which `to_python` should run. If mode is 'json', the output will only contain JSON serializable types. If mode is 'python', the output may contain non-JSON-serializable Python objects. |
| `include` | `IncEx \| None` | A set of fields to include in the output. |
| `exclude` | `IncEx \| None` | A set of fields to exclude from the output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to use the field's alias in the dictionary key if defined. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
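The following sketch assumes the inherited Pydantic behavior. It shows how `mode` changes the output types, and how `exclude_none` drops unset optional values, using a hypothetical `Release` model.

```python
from datetime import date
from typing import Optional

from pydantic import BaseModel


class Release(BaseModel):  # hypothetical model
    name: str
    released: date
    notes: Optional[str] = None


r = Release(name="v2", released=date(2024, 1, 15))

# mode="python" (the default) keeps native objects such as datetime.date
print(r.model_dump())
# {'name': 'v2', 'released': datetime.date(2024, 1, 15), 'notes': None}

# mode="json" converts values to JSON-safe types; exclude_none drops notes
print(r.model_dump(mode="json", exclude_none=True))
# {'name': 'v2', 'released': '2024-01-15'}
```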
### model_dump_json()
```python
def model_dump_json(
indent: int | None,
ensure_ascii: bool,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> str
```
Generates a JSON representation of the model using Pydantic's `to_json` method.
| Parameter | Type | Description |
|-|-|-|
| `indent` | `int \| None` | Indentation to use in the JSON output. If None is passed, the output will be compact. |
| `ensure_ascii` | `bool` | If `True`, the output is guaranteed to have all incoming non-ASCII characters escaped. If `False` (the default), these characters will be output as-is. |
| `include` | `IncEx \| None` | Field(s) to include in the JSON output. |
| `exclude` | `IncEx \| None` | Field(s) to exclude from the JSON output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to serialize using field aliases. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
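This is a short sketch, assuming the inherited `pydantic.BaseModel.model_dump_json()`. It contrasts compact output (`indent=None`) with indented output that also omits default values. `Greeting` is a hypothetical model.

```python
import json

from pydantic import BaseModel


class Greeting(BaseModel):  # hypothetical model
    text: str
    lang: str = "en"


g = Greeting(text="hello")

compact = g.model_dump_json()  # indent=None gives compact output
pretty = g.model_dump_json(indent=2, exclude_defaults=True)

print(compact)  # {"text":"hello","lang":"en"}
print(pretty)
# {
#   "text": "hello"
# }
```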
### model_json_schema()
```python
def model_json_schema(
by_alias: bool,
ref_template: str,
schema_generator: type[GenerateJsonSchema],
mode: JsonSchemaMode,
union_format: Literal['any_of', 'primitive_type_array'],
) -> dict[str, Any]
```
Generates a JSON schema for a model class.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | Whether to use attribute aliases or not. |
| `ref_template` | `str` | The reference template used for `$ref` entries (defaults to `'#/$defs/{model}'`). |
| `schema_generator` | `type[GenerateJsonSchema]` | To override the logic used to generate the JSON schema, as a subclass of `GenerateJsonSchema` with your desired modifications. |
| `mode` | `JsonSchemaMode` | The mode in which to generate the schema: `'validation'` (the default) or `'serialization'`. |
| `union_format` | `Literal['any_of', 'primitive_type_array']` | The format to use when combining schemas from unions. `'any_of'` uses the [`anyOf`](https://json-schema.org/understanding-json-schema/reference/combining#anyOf) keyword to combine schemas (the default). `'primitive_type_array'` uses the [`type`](https://json-schema.org/understanding-json-schema/reference/type) keyword as an array of strings, containing each type of the combination; if any of the schemas is not a primitive type (`string`, `boolean`, `null`, `integer` or `number`) or contains constraints/metadata, it falls back to `any_of`. |
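As a sketch that assumes the inherited Pydantic classmethod, the example below shows how `by_alias` decides the property keys in the generated schema. It uses a hypothetical `Job` model with one aliased field.

```python
from pydantic import BaseModel, Field


class Job(BaseModel):  # hypothetical model
    name: str
    retries: int = Field(default=0, alias="maxRetries")


# by_alias=True is the default, so the alias becomes the property key
with_alias = Job.model_json_schema()
without_alias = Job.model_json_schema(by_alias=False)

print(sorted(with_alias["properties"]))     # ['maxRetries', 'name']
print(sorted(without_alias["properties"]))  # ['name', 'retries']
print(with_alias["required"])               # ['name']
```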
### model_parametrized_name()
```python
def model_parametrized_name(
params: tuple[type[Any], ...],
) -> str
```
Compute the class name for parametrizations of generic classes.
This method can be overridden to achieve a custom naming scheme for generic BaseModels.
| Parameter | Type | Description |
|-|-|-|
| `params` | `tuple[type[Any], ...]` | Tuple of types of the class. Given a generic class `Model` with 2 type variables and a concrete model `Model[str, int]`, the value `(str, int)` would be passed to `params`. |
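The sketch below shows what overriding this hook can look like on a plain Pydantic generic model. It assumes the Pydantic behavior documented above, and `Box` and its naming scheme are hypothetical. Without the override, the parametrized class would be named `Box[int]`.

```python
from typing import Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T")


class Box(BaseModel, Generic[T]):  # hypothetical generic model
    item: T

    @classmethod
    def model_parametrized_name(cls, params: tuple) -> str:
        # params holds the concrete types, e.g. (int,) for Box[int]
        return cls.__name__ + "Of" + "And".join(p.__name__.capitalize() for p in params)


print(Box[int].__name__)      # BoxOfInt
print(Box[int](item=3).item)  # 3
```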
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | |
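Here is a minimal sketch of a cross-field check in `model_post_init()` on a hypothetical Pydantic model. Every field is already validated by the time the hook runs. Depending on the Pydantic version, an error raised in the hook may surface as a `ValidationError`, which subclasses `ValueError`, so the sketch catches `ValueError`.

```python
from typing import Any

from pydantic import BaseModel, PrivateAttr


class Window(BaseModel):  # hypothetical model
    start: int
    end: int
    _span: int = PrivateAttr(default=0)

    def model_post_init(self, context: Any) -> None:
        # All fields are validated by now, so a cross-field check is safe here
        if self.end < self.start:
            raise ValueError("end must be >= start")
        self._span = self.end - self.start


w = Window(start=2, end=5)
print(w._span)  # 3

try:
    Window(start=5, end=2)
    rejected = False
except ValueError:  # pydantic's ValidationError also subclasses ValueError
    rejected = True
```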
### model_rebuild()
```python
def model_rebuild(
force: bool,
raise_errors: bool,
_parent_namespace_depth: int,
_types_namespace: MappingNamespace | None,
) -> bool | None
```
Try to rebuild the pydantic-core schema for the model.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
the initial attempt to build the schema, and automatic rebuilding fails.
| Parameter | Type | Description |
|-|-|-|
| `force` | `bool` | Whether to force the rebuilding of the model schema, defaults to `False`. |
| `raise_errors` | `bool` | Whether to raise errors, defaults to `True`. |
| `_parent_namespace_depth` | `int` | The depth level of the parent namespace, defaults to 2. |
| `_types_namespace` | `MappingNamespace \| None` | The types namespace, defaults to `None`. |
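The following sketch uses hypothetical Pydantic models and assumes the inherited behavior. It shows an explicit `model_rebuild()` call after a forward-referenced class has been defined.

```python
from typing import Optional

from pydantic import BaseModel


class Parent(BaseModel):  # hypothetical model
    # "Child" is a forward reference: the class does not exist yet
    child: Optional["Child"] = None


class Child(BaseModel):  # hypothetical model
    name: str


# Child is now defined, so rebuilding resolves the forward reference
Parent.model_rebuild()
p = Parent(child={"name": "c1"})
print(p.child.name)  # c1
```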
### model_validate()
```python
def model_validate(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
from_attributes: bool | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate a pydantic model instance.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `from_attributes` | `bool \| None` | Whether to extract data from object attributes. |
| `context` | `Any \| None` | Additional context to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
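This sketch assumes `model_validate()` is the inherited Pydantic classmethod. Using a hypothetical `User` model, it contrasts three behaviors: lax coercion, `from_attributes=True` with an attribute-bearing object, and `strict=True`.

```python
from dataclasses import dataclass

from pydantic import BaseModel, ValidationError


class User(BaseModel):  # hypothetical model
    id: int
    name: str


@dataclass
class UserRow:  # stands in for an ORM row or any attribute-bearing object
    id: int
    name: str


# Lax mode (the default) coerces the numeric string "7" to the int 7
u1 = User.model_validate({"id": "7", "name": "ada"})

# from_attributes=True reads fields from object attributes instead of dict keys
u2 = User.model_validate(UserRow(id=8, name="bob"), from_attributes=True)

# strict=True disables the str -> int coercion
try:
    User.model_validate({"id": "7", "name": "ada"}, strict=True)
    strict_rejected = False
except ValidationError:
    strict_rejected = True

print(u1.id, u2.name, strict_rejected)  # 7 bob True
```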
### model_validate_json()
```python
def model_validate_json(
json_data: str | bytes | bytearray,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given JSON data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `json_data` | `str \| bytes \| bytearray` | The JSON data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
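Here is a short sketch with a hypothetical `Point` model, assuming the inherited Pydantic behavior. Both `str` and `bytes` input are accepted, and a missing field is reported in the `ValidationError`.

```python
from pydantic import BaseModel, ValidationError


class Point(BaseModel):  # hypothetical model
    x: int
    y: int


p = Point.model_validate_json('{"x": 1, "y": 2}')
q = Point.model_validate_json(b'{"x": 3, "y": 4}')  # bytes input works too

try:
    Point.model_validate_json('{"x": 1}')  # "y" is missing
    error_type = None
except ValidationError as exc:
    error_type = exc.errors()[0]["type"]

print(p, q, error_type)  # x=1 y=2 x=3 y=4 missing
```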
### model_validate_strings()
```python
def model_validate_strings(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given object with string data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object containing string data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
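The sketch below assumes the inherited Pydantic classmethod and a hypothetical `Event` model. It validates a mapping whose leaf values are all strings, as you might get from query parameters or environment variables.

```python
from datetime import date

from pydantic import BaseModel


class Event(BaseModel):  # hypothetical model
    count: int
    day: date
    active: bool


# Every leaf value is a string, as with query parameters or environment variables
e = Event.model_validate_strings({"count": "3", "day": "2024-05-01", "active": "true"})
print(e.count, e.day, e.active)  # 3 2024-05-01 True
```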
### parse_file()
```python
def parse_file(
path: str | Path,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str \| Path` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### parse_obj()
```python
def parse_obj(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### parse_raw()
```python
def parse_raw(
b: str | bytes,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `b` | `str \| bytes` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### schema()
```python
def schema(
by_alias: bool,
ref_template: str,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
### schema_json()
```python
def schema_json(
by_alias: bool,
ref_template: str,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
| `dumps_kwargs` | `Any` | |
### update_forward_refs()
```python
def update_forward_refs(
localns: Any,
)
```
| Parameter | Type | Description |
|-|-|-|
| `localns` | `Any` | |
### validate()
```python
def validate(
value: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `value` | `Any` | |
## Properties
| Property | Type | Description |
|-|-|-|
| `model_extra` | `dict[str, Any] \| None` | Get extra fields set during validation. Returns: A dictionary of extra fields, or `None` if `config.extra` is not set to `"allow"`. |
| `model_fields_set` | `set[str]` | Returns the set of fields that have been explicitly set on this model instance. Returns: A set of strings representing the fields that have been set, i.e. that were not filled from defaults. |
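A minimal sketch of both properties, assuming they behave as on `pydantic.BaseModel`, on two hypothetical models. One model keeps unknown keys (`extra="allow"`) and the other uses the default configuration.

```python
from pydantic import BaseModel, ConfigDict


class Flexible(BaseModel):  # hypothetical model that keeps unknown keys
    model_config = ConfigDict(extra="allow")
    name: str
    size: int = 1


class Plain(BaseModel):  # hypothetical model using the default extra="ignore"
    name: str


f = Flexible(name="x", color="red")
print(f.model_extra)                 # {'color': 'red'}
print("size" in f.model_fields_set)  # False: size came from its default
print(Plain(name="y").model_extra)   # None
```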
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.app/runoutput ===
# RunOutput
**Package:** `flyte.app`
Use a run's output for app inputs.
This declares an app input dependency on the output of a run, identified either by a specific run name or by a task name and version.
If `task_auto_version == 'latest'`, the latest version of the task is used.
If `task_auto_version == 'current'`, the version is derived from the callee
app or task context.
```python
class RunOutput(
data: Any,
)
```
Create a new model by parsing and validating input data from keyword arguments.
Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be
validated to form a valid model.
`self` is explicitly positional-only to allow `self` as a field name.
| Parameter | Type | Description |
|-|-|-|
| `data` | `Any` | |
## Methods
| Method | Description |
|-|-|
| `check_type()` | |
| `construct()` | |
| `copy()` | Returns a copy of the model. |
| `dict()` | |
| `from_orm()` | |
| `get()` | |
| `json()` | |
| `materialize()` | |
| `model_construct()` | Creates a new instance of the `Model` class from already-validated data. |
| `model_copy()` | Returns a copy of the model. |
| `model_dump()` | Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. |
| `model_dump_json()` | Generates a JSON representation of the model using Pydantic's `to_json` method. |
| `model_json_schema()` | Generates a JSON schema for a model class. |
| `model_parametrized_name()` | Compute the class name for parametrizations of generic classes. |
| `model_post_init()` | Override this method to perform additional initialization after `__init__` and `model_construct`. |
| `model_rebuild()` | Try to rebuild the pydantic-core schema for the model. |
| `model_validate()` | Validate a pydantic model instance. |
| `model_validate_json()` | Validate the given JSON data against the Pydantic model. |
| `model_validate_strings()` | Validate the given object with string data against the Pydantic model. |
| `parse_file()` | |
| `parse_obj()` | |
| `parse_raw()` | |
| `schema()` | |
| `schema_json()` | |
| `update_forward_refs()` | |
| `validate()` | |
### check_type()
```python
def check_type(
data: typing.Any,
) -> typing.Any
```
| Parameter | Type | Description |
|-|-|-|
| `data` | `typing.Any` | |
### construct()
```python
def construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | |
| `values` | `Any` | |
### copy()
```python
def copy(
include: AbstractSetIntStr | MappingIntStrAny | None,
exclude: AbstractSetIntStr | MappingIntStrAny | None,
update: Dict[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!WARNING] Deprecated
> This method is now deprecated; use `model_copy` instead.
If you need `include` or `exclude`, use:
```python {test="skip" lint="skip"}
data = self.model_dump(include=include, exclude=exclude, round_trip=True)
data = {**data, **(update or {})}
copied = self.model_validate(data)
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to include in the copied model. |
| `exclude` | `AbstractSetIntStr \| MappingIntStrAny \| None` | Optional set or mapping specifying which fields to exclude in the copied model. |
| `update` | `Dict[str, Any] \| None` | Optional dictionary of field-value pairs to override field values in the copied model. |
| `deep` | `bool` | If True, the values of fields that are Pydantic models will be deep-copied. |
### dict()
```python
def dict(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
) -> Dict[str, Any]
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
### from_orm()
```python
def from_orm(
obj: Any,
) -> Self
```
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### get()
```python
def get()
```
### json()
```python
def json(
include: IncEx | None,
exclude: IncEx | None,
by_alias: bool,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
encoder: Callable[[Any], Any] | None,
models_as_dict: bool,
dumps_kwargs: Any,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `include` | `IncEx \| None` | |
| `exclude` | `IncEx \| None` | |
| `by_alias` | `bool` | |
| `exclude_unset` | `bool` | |
| `exclude_defaults` | `bool` | |
| `exclude_none` | `bool` | |
| `encoder` | `Callable[[Any], Any] \| None` | |
| `models_as_dict` | `bool` | |
| `dumps_kwargs` | `Any` | |
### materialize()
```python
def materialize()
```
### model_construct()
```python
def model_construct(
_fields_set: set[str] | None,
values: Any,
) -> Self
```
Creates a new instance of the `Model` class from already-validated data.
Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
> [!NOTE]
> `model_construct()` generally respects the `model_config.extra` setting on the provided model.
> That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
> and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
> Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
> an error if extra values are passed, but they will be ignored.
| Parameter | Type | Description |
|-|-|-|
| `_fields_set` | `set[str] \| None` | A set of field names that were originally explicitly set during instantiation. If provided, this is directly used for the [`model_fields_set`][pydantic.BaseModel.model_fields_set] attribute. Otherwise, the field names from the `values` argument will be used. |
| `values` | `Any` | Trusted or pre-validated data dictionary. |
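The sketch below illustrates the `model_config.extra` note above and the explicit `_fields_set` argument. It assumes `RunOutput` inherits `model_construct()` unchanged from `pydantic.BaseModel`. Because `RunOutput`'s own fields are not documented here, it uses hypothetical stand-in models.

```python
from pydantic import BaseModel, ConfigDict


class Loose(BaseModel):  # hypothetical model that allows extra keys
    model_config = ConfigDict(extra="allow")
    a: int


class Closed(BaseModel):  # hypothetical model that forbids extra keys
    model_config = ConfigDict(extra="forbid")
    a: int


loose = Loose.model_construct(a=1, b=2)
closed = Closed.model_construct(a=1, b=2)  # no error, since nothing is validated
marked = Loose.model_construct(_fields_set=set(), a=1)  # fields set taken as given

print(loose.model_extra)        # {'b': 2}
print(closed.a)                 # 1
print(marked.model_fields_set)  # set()
```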
### model_copy()
```python
def model_copy(
update: Mapping[str, Any] | None,
deep: bool,
) -> Self
```
Returns a copy of the model.
> [!NOTE]
> The underlying instance's [`__dict__`][object.__dict__] attribute is copied. This
> might have unexpected side effects if you store anything in it, on top of the model
> fields (e.g. the value of [cached properties][functools.cached_property]).
| Parameter | Type | Description |
|-|-|-|
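| `update` | `Mapping[str, Any] \| None` | Values to change or add in the new model. The data is not validated before the new model is created, so it must be trusted. |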
| `deep` | `bool` | Set to `True` to make a deep copy of the model. |
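To make the `__dict__` caveat in the note above concrete, here is a sketch on a hypothetical Pydantic model, assuming `RunOutput.model_copy()` is the inherited Pydantic method. A `functools.cached_property` value stored in the instance `__dict__` is copied along with it. As a result, the copy reports a stale value even though `update=` changed a field it depends on.

```python
from functools import cached_property

from pydantic import BaseModel


class Rect(BaseModel):  # hypothetical model
    width: int
    height: int

    @cached_property
    def area(self) -> int:
        # Stored in the instance __dict__ after the first access
        return self.width * self.height


r = Rect(width=2, height=3)
print(r.area)  # 6 (now cached in r.__dict__)

wider = r.model_copy(update={"width": 10})
print(wider.width, wider.area)  # 10 6 (the stale cached area was copied)
```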
### model_dump()
```python
def model_dump(
mode: Literal['json', 'python'] | str,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> dict[str, Any]
```
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
| Parameter | Type | Description |
|-|-|-|
| `mode` | `Literal['json', 'python'] \| str` | The mode in which `to_python` should run. If mode is 'json', the output will only contain JSON serializable types. If mode is 'python', the output may contain non-JSON-serializable Python objects. |
| `include` | `IncEx \| None` | A set of fields to include in the output. |
| `exclude` | `IncEx \| None` | A set of fields to exclude from the output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to use the field's alias in the dictionary key if defined. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
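The following sketch assumes the inherited Pydantic behavior and uses hypothetical nested models. It shows that `include` accepts a nested mapping to select fields inside a sub-model, and that `by_alias` switches the output keys to field aliases.

```python
from pydantic import BaseModel, Field


class Resources(BaseModel):  # hypothetical nested model
    cpu: int
    mem_gb: int = Field(alias="memGb")


class Spec(BaseModel):  # hypothetical model
    name: str
    resources: Resources


spec = Spec(name="train", resources={"cpu": 2, "memGb": 8})

# include= accepts nested sets/dicts to pick fields inside sub-models
print(spec.model_dump(include={"resources": {"cpu"}}))
# {'resources': {'cpu': 2}}

# by_alias=True emits "memGb" instead of the Python attribute name
print(spec.model_dump(by_alias=True))
# {'name': 'train', 'resources': {'cpu': 2, 'memGb': 8}}
```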
### model_dump_json()
```python
def model_dump_json(
indent: int | None,
ensure_ascii: bool,
include: IncEx | None,
exclude: IncEx | None,
context: Any | None,
by_alias: bool | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
exclude_computed_fields: bool,
round_trip: bool,
warnings: bool | Literal['none', 'warn', 'error'],
fallback: Callable[[Any], Any] | None,
serialize_as_any: bool,
) -> str
```
Generates a JSON representation of the model using Pydantic's `to_json` method.
| Parameter | Type | Description |
|-|-|-|
| `indent` | `int \| None` | Indentation to use in the JSON output. If None is passed, the output will be compact. |
| `ensure_ascii` | `bool` | If `True`, the output is guaranteed to have all incoming non-ASCII characters escaped. If `False` (the default), these characters will be output as-is. |
| `include` | `IncEx \| None` | Field(s) to include in the JSON output. |
| `exclude` | `IncEx \| None` | Field(s) to exclude from the JSON output. |
| `context` | `Any \| None` | Additional context to pass to the serializer. |
| `by_alias` | `bool \| None` | Whether to serialize using field aliases. |
| `exclude_unset` | `bool` | Whether to exclude fields that have not been explicitly set. |
| `exclude_defaults` | `bool` | Whether to exclude fields that are set to their default value. |
| `exclude_none` | `bool` | Whether to exclude fields that have a value of `None`. |
| `exclude_computed_fields` | `bool` | Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated `round_trip` parameter instead. |
| `round_trip` | `bool` | If True, dumped values should be valid as input for non-idempotent types such as Json[T]. |
| `warnings` | `bool \| Literal['none', 'warn', 'error']` | How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. |
| `fallback` | `Callable[[Any], Any] \| None` | A function to call when an unknown value is encountered. If not provided, a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. |
| `serialize_as_any` | `bool` | Whether to serialize fields with duck-typing serialization behavior. |
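This sketch illustrates `round_trip` with Pydantic's `Json[T]` type on a hypothetical `Envelope` model, assuming the inherited Pydantic behavior. With `round_trip=True`, the parsed payload is re-encoded as a JSON string, so the output can be fed back to `model_validate_json()`.

```python
import json

from pydantic import BaseModel, Json


class Envelope(BaseModel):  # hypothetical model
    # Json[...] parses a JSON-encoded string into the inner type on validation
    payload: Json[dict[str, int]]


env = Envelope(payload='{"a": 1}')
print(env.payload)  # {'a': 1}

plain = env.model_dump_json()                # payload dumped as a JSON object
trip = env.model_dump_json(round_trip=True)  # payload re-encoded as a string

print(json.loads(plain))  # {'payload': {'a': 1}}
# The round_trip output can be fed straight back into validation
print(Envelope.model_validate_json(trip) == env)  # True
```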
### model_json_schema()
```python
def model_json_schema(
by_alias: bool,
ref_template: str,
schema_generator: type[GenerateJsonSchema],
mode: JsonSchemaMode,
union_format: Literal['any_of', 'primitive_type_array'],
) -> dict[str, Any]
```
Generates a JSON schema for a model class.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | Whether to use attribute aliases or not. |
| `ref_template` | `str` | The reference template used for `$ref` entries (defaults to `'#/$defs/{model}'`). |
| `schema_generator` | `type[GenerateJsonSchema]` | To override the logic used to generate the JSON schema, as a subclass of `GenerateJsonSchema` with your desired modifications. |
| `mode` | `JsonSchemaMode` | The mode in which to generate the schema: `'validation'` (the default) or `'serialization'`. |
| `union_format` | `Literal['any_of', 'primitive_type_array']` | The format to use when combining schemas from unions. `'any_of'` uses the [`anyOf`](https://json-schema.org/understanding-json-schema/reference/combining#anyOf) keyword to combine schemas (the default). `'primitive_type_array'` uses the [`type`](https://json-schema.org/understanding-json-schema/reference/type) keyword as an array of strings, containing each type of the combination; if any of the schemas is not a primitive type (`string`, `boolean`, `null`, `integer` or `number`) or contains constraints/metadata, it falls back to `any_of`. |
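Here is a sketch of the `mode` parameter on a hypothetical Pydantic model with a computed field, assuming the inherited classmethod. The computed field is only an output, so it appears in the serialization schema but not in the validation schema.

```python
from pydantic import BaseModel, computed_field


class Panel(BaseModel):  # hypothetical model
    w: int
    h: int

    @computed_field
    @property
    def area(self) -> int:
        return self.w * self.h


validation = Panel.model_json_schema(mode="validation")  # the default
serialization = Panel.model_json_schema(mode="serialization")

print(sorted(validation["properties"]))     # ['h', 'w']
print(sorted(serialization["properties"]))  # ['area', 'h', 'w']
```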
### model_parametrized_name()
```python
def model_parametrized_name(
params: tuple[type[Any], ...],
) -> str
```
Compute the class name for parametrizations of generic classes.
This method can be overridden to achieve a custom naming scheme for generic BaseModels.
| Parameter | Type | Description |
|-|-|-|
| `params` | `tuple[type[Any], ...]` | Tuple of types of the class. Given a generic class `Model` with 2 type variables and a concrete model `Model[str, int]`, the value `(str, int)` would be passed to `params`. |
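For reference, here is the default naming scheme when the hook is not overridden. This sketch assumes Pydantic's built-in implementation and uses a hypothetical two-parameter generic model.

```python
from typing import Generic, TypeVar

from pydantic import BaseModel

K = TypeVar("K")
V = TypeVar("V")


class Pair(BaseModel, Generic[K, V]):  # hypothetical generic model
    key: K
    value: V


# Without an override, the name lists the concrete parameters in brackets
print(Pair[str, int].__name__)  # Pair[str, int]
```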
### model_post_init()
```python
def model_post_init(
context: Any,
)
```
Override this method to perform additional initialization after `__init__` and `model_construct`.
This is useful if you want to do some validation that requires the entire model to be initialized.
| Parameter | Type | Description |
|-|-|-|
| `context` | `Any` | |
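The description above says the hook runs after both `__init__` and `model_construct`. The sketch below checks that on a hypothetical Pydantic model, assuming the inherited behavior, by recording each invocation in a module-level list.

```python
from typing import Any

from pydantic import BaseModel

calls: list[str] = []  # records every model_post_init invocation


class Tracked(BaseModel):  # hypothetical model
    name: str

    def model_post_init(self, context: Any) -> None:
        calls.append(self.name)


Tracked(name="via-init")
Tracked.model_construct(name="via-construct")  # post-init runs here too
print(calls)  # ['via-init', 'via-construct']
```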
### model_rebuild()
```python
def model_rebuild(
force: bool,
raise_errors: bool,
_parent_namespace_depth: int,
_types_namespace: MappingNamespace | None,
) -> bool | None
```
Try to rebuild the pydantic-core schema for the model.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
the initial attempt to build the schema, and automatic rebuilding fails.
| Parameter | Type | Description |
|-|-|-|
| `force` | `bool` | Whether to force the rebuilding of the model schema, defaults to `False`. |
| `raise_errors` | `bool` | Whether to raise errors, defaults to `True`. |
| `_parent_namespace_depth` | `int` | The depth level of the parent namespace, defaults to 2. |
| `_types_namespace` | `MappingNamespace \| None` | The types namespace, defaults to `None`. |
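The sketch below illustrates `force`, `raise_errors` and the `bool | None` return value, assuming the inherited Pydantic behavior. `Node` and `Dangling` are hypothetical models. `Node` is a self-referencing model whose schema is complete at class creation. `Dangling` refers to a name that never exists.

```python
from typing import Optional

from pydantic import BaseModel


class Node(BaseModel):  # hypothetical self-referencing model
    value: int
    next: Optional["Node"] = None


class Dangling(BaseModel):  # hypothetical model whose reference never resolves
    other: "Undefined"  # noqa: F821


print(Node.model_rebuild())                        # None: already complete
print(Node.model_rebuild(force=True))              # True: rebuilt anyway
print(Dangling.model_rebuild(raise_errors=False))  # False: could not resolve
```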
### model_validate()
```python
def model_validate(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
from_attributes: bool | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate a pydantic model instance.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the [`extra` configuration value][pydantic.ConfigDict.extra] for details. |
| `from_attributes` | `bool \| None` | Whether to extract data from object attributes. |
| `context` | `Any \| None` | Additional context to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
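
The sketch below shows the `strict` and `from_attributes` parameters in use. It assumes Pydantic V2, and `Point` and `Record` are hypothetical classes.

```python
from pydantic import BaseModel, ValidationError


class Point(BaseModel):
    x: int
    y: int


# Lax mode (the default) coerces the numeric string "2" to the int 2.
p = Point.model_validate({"x": 1, "y": "2"})
print(p)  # x=1 y=2

# Strict mode rejects the same input because "2" is a str, not an int.
try:
    Point.model_validate({"x": 1, "y": "2"}, strict=True)
except ValidationError as exc:
    print(exc.error_count())  # 1


class Record:
    # Arbitrary object that exposes the fields as attributes.
    def __init__(self) -> None:
        self.x = 3
        self.y = 4


# from_attributes=True reads fields from object attributes instead of dict keys.
q = Point.model_validate(Record(), from_attributes=True)
```
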
### model_validate_json()
```python
def model_validate_json(
json_data: str | bytes | bytearray,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given JSON data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `json_data` | `str \| bytes \| bytearray` | The JSON data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the `extra` configuration value (`pydantic.ConfigDict.extra`) for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
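
`model_validate_json()` parses and validates in a single step. The sketch assumes Pydantic V2 and uses a hypothetical `Point` model. Following the signature above, it passes the JSON both as `str` and as `bytes`.

```python
from pydantic import BaseModel, ValidationError


class Point(BaseModel):
    x: int
    y: int


p1 = Point.model_validate_json('{"x": 1, "y": 2}')
p2 = Point.model_validate_json(b'{"x": 1, "y": 2}')
print(p1 == p2)  # True

# Malformed JSON is reported as a ValidationError, not as a json.JSONDecodeError.
try:
    Point.model_validate_json('{"x": 1, "y":')
except ValidationError as exc:
    print(exc.errors()[0]["type"])
```
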
### model_validate_strings()
```python
def model_validate_strings(
obj: Any,
strict: bool | None,
extra: ExtraValues | None,
context: Any | None,
by_alias: bool | None,
by_name: bool | None,
) -> Self
```
Validate the given object with string data against the Pydantic model.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | The object containing string data to validate. |
| `strict` | `bool \| None` | Whether to enforce types strictly. |
| `extra` | `ExtraValues \| None` | Whether to ignore, allow, or forbid extra data during model validation. See the `extra` configuration value (`pydantic.ConfigDict.extra`) for details. |
| `context` | `Any \| None` | Extra variables to pass to the validator. |
| `by_alias` | `bool \| None` | Whether to use the field's alias when validating against the provided input data. |
| `by_name` | `bool \| None` | Whether to use the field's name when validating against the provided input data. |
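
`model_validate_strings()` is useful when every input value arrives as text, for example from a query string or a CSV row. The sketch assumes a recent Pydantic V2 release, and the `Event` model is hypothetical.

```python
from datetime import date

from pydantic import BaseModel


class Event(BaseModel):
    count: int
    day: date
    active: bool


# Every value is a string; each is parsed into the declared field type.
raw = {"count": "42", "day": "2024-01-02", "active": "true"}
event = Event.model_validate_strings(raw)
print(event)  # count=42 day=datetime.date(2024, 1, 2) active=True
```
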
### parse_file()
```python
def parse_file(
path: str | Path,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
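```
Deprecated since Pydantic V2. Load the data from the file yourself, then use `model_validate_json()` for JSON data or `model_validate()` otherwise.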
| Parameter | Type | Description |
|-|-|-|
| `path` | `str \| Path` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### parse_obj()
```python
def parse_obj(
obj: Any,
) -> Self
```
Deprecated since Pydantic V2; use `model_validate()` instead.
| Parameter | Type | Description |
|-|-|-|
| `obj` | `Any` | |
### parse_raw()
```python
def parse_raw(
b: str | bytes,
content_type: str | None,
encoding: str,
proto: DeprecatedParseProtocol | None,
allow_pickle: bool,
) -> Self
```
Deprecated since Pydantic V2; use `model_validate_json()` instead.
| Parameter | Type | Description |
|-|-|-|
| `b` | `str \| bytes` | |
| `content_type` | `str \| None` | |
| `encoding` | `str` | |
| `proto` | `DeprecatedParseProtocol \| None` | |
| `allow_pickle` | `bool` | |
### schema()
```python
def schema(
by_alias: bool,
ref_template: str,
) -> Dict[str, Any]
```
Deprecated since Pydantic V2; use `model_json_schema()` instead.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
### schema_json()
```python
def schema_json(
by_alias: bool,
ref_template: str,
dumps_kwargs: Any,
) -> str
```
Deprecated since Pydantic V2; use `model_json_schema()` together with `json.dumps()` instead.
| Parameter | Type | Description |
|-|-|-|
| `by_alias` | `bool` | |
| `ref_template` | `str` | |
| `dumps_kwargs` | `Any` | |
### update_forward_refs()
```python
def update_forward_refs(
localns: Any,
)
```
Deprecated since Pydantic V2; use `model_rebuild()` instead.
| Parameter | Type | Description |
|-|-|-|
| `localns` | `Any` | |
### validate()
```python
def validate(
value: Any,
) -> Self
```
Deprecated since Pydantic V2; use `model_validate()` instead.
| Parameter | Type | Description |
|-|-|-|
| `value` | `Any` | |
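
The deprecated methods above each have a Pydantic V2 replacement. The sketch below shows those replacements on a hypothetical `Point` model and assumes Pydantic V2. It does not cover `parse_file()`, whose replacement is to read the file first and then pass the contents to `model_validate_json()`.

```python
import json

from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: int


p = Point.model_validate({"x": 1, "y": 2})           # replaces parse_obj() and validate()
q = Point.model_validate_json('{"x": 1, "y": 2}')    # replaces parse_raw()
schema = Point.model_json_schema()                   # replaces schema()
schema_text = json.dumps(Point.model_json_schema())  # replaces schema_json()
Point.model_rebuild()                                # replaces update_forward_refs()
print(sorted(schema["properties"]))  # ['x', 'y']
```
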
## Properties
| Property | Type | Description |
|-|-|-|
| `model_extra` | `dict[str, Any] \| None` | Get extra fields set during validation. Returns: A dictionary of extra fields, or `None` if `config.extra` is not set to `"allow"`. |
| `model_fields_set` | `set[str]` | Returns the set of fields that have been explicitly set on this model instance. Returns: A set of strings representing the fields that have been set, i.e. that were not filled from defaults. |
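
The sketch below shows both properties on two hypothetical models. It assumes Pydantic V2.

```python
from pydantic import BaseModel, ConfigDict


class Settings(BaseModel):
    name: str
    retries: int = 3


s = Settings(name="demo")
print(s.model_fields_set)  # {'name'}  (retries was filled from its default)
print(s.model_extra)       # None  (extra is not "allow")


class LooseSettings(BaseModel):
    model_config = ConfigDict(extra="allow")
    name: str


t = LooseSettings(name="demo", color="red")
print(t.model_extra)  # {'color': 'red'}
```
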
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.config ===
# flyte.config
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.config > `Config`** | This is the parent configuration object, and it holds all the underlying configuration object types. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.config > `auto()`** | Automatically constructs the Config object. |
| **Flyte SDK > Packages > flyte.config > `set_if_exists()`** | Given a dict `d`, sets the key `k` to the config value `val` if `val` is set. |
## Methods
#### auto()
```python
def auto(
config_file: typing.Union[str, pathlib.Path, ConfigFile, None],
) -> Config
```
Automatically constructs the Config object. The order of precedence is as follows:
1. If specified, read the config from the provided file path.
2. If not specified, search for the config file in the default locations, in this order:
   1. `./config.yaml`, if it exists (current working directory)
   2. `./.flyte/config.yaml`, if it exists (current working directory)
   3. `<git_root>/.flyte/config.yaml`, if it exists
   4. The `UCTL_CONFIG` environment variable
   5. The `FLYTECTL_CONFIG` environment variable
   6. `~/.union/config.yaml`, if it exists
   7. `~/.flyte/config.yaml`, if it exists
3. If a value is not found in the config file, the default value is used.
4. For any value, an environment variable whose name matches the config variable name overrides the value from the config file.

| Parameter | Type | Description |
|-|-|-|
| `config_file` | `typing.Union[str, pathlib.Path, ConfigFile, None]` | File path to read the config from. If not specified, the default locations are searched. Returns: The constructed `Config`. |
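
The file-search part of that precedence order can be sketched with the standard library alone. `find_config_file` is a hypothetical helper, not the SDK's implementation. The sketch makes two assumptions: the environment variables hold a file path, and a candidate counts only if it exists as a file. It does not model the per-value environment overrides in step 4.

```python
import os
from pathlib import Path
from typing import Optional, Union


def find_config_file(
    config_file: Union[str, Path, None] = None,
    cwd: Optional[Path] = None,
    git_root: Optional[Path] = None,
    home: Optional[Path] = None,
) -> Optional[Path]:
    """Hypothetical: return the first config file found, in the documented order."""
    # 1. An explicitly provided path always wins.
    if config_file is not None:
        return Path(config_file)

    cwd = cwd or Path.cwd()
    home = home or Path.home()

    # 2.1-2.3: files relative to the working directory and the git root.
    candidates = [cwd / "config.yaml", cwd / ".flyte" / "config.yaml"]
    if git_root is not None:
        candidates.append(git_root / ".flyte" / "config.yaml")

    # 2.4-2.5: paths named by environment variables (assumed to point at a file).
    for var in ("UCTL_CONFIG", "FLYTECTL_CONFIG"):
        value = os.environ.get(var)
        if value:
            candidates.append(Path(value))

    # 2.6-2.7: per-user files in the home directory.
    candidates += [home / ".union" / "config.yaml", home / ".flyte" / "config.yaml"]

    for candidate in candidates:
        if candidate.is_file():
            return candidate
    return None  # 3. No file found, so defaults apply.
```
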
#### set_if_exists()
```python
def set_if_exists(
d: dict,
k: str,
val: typing.Any,
) -> dict
```
Given a dict `d`, sets the key `k` to the config value `val` if `val` is set,
and returns the updated dictionary.
| Parameter | Type | Description |
|-|-|-|
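| `d` | `dict` | The dictionary to update. |
| `k` | `str` | The key to set in `d`. |
| `val` | `typing.Any` | The config value to assign; `d` is left unchanged if it is not set. |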
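
This page does not define what "is set" means. The hypothetical re-implementation below assumes two rules: booleans always count as set, so an explicit `False` is kept, and any other value counts as set when it is truthy. Check the SDK source before relying on the exact rule.

```python
from typing import Any


def set_if_exists(d: dict, k: str, val: Any) -> dict:
    """Hypothetical sketch: set d[k] = val only when val counts as "set", then return d."""
    # Assumption: booleans are always kept; other values are kept only when truthy,
    # so None, "" and empty containers are skipped.
    if isinstance(val, bool) or val:
        d[k] = val
    return d


cfg: dict = {}
set_if_exists(cfg, "endpoint", "dns:///example.com")
set_if_exists(cfg, "insecure", False)
set_if_exists(cfg, "api_key", None)
print(cfg)  # {'endpoint': 'dns:///example.com', 'insecure': False}
```
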
## Subpages
- [Config](Config/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.errors ===
# flyte.errors
Exceptions raised by Union.
These errors are raised when the underlying task execution fails, either because of a user error, system error or an
unknown error.
## Directory
### Errors
| Exception | Description |
|-|-|
| **Flyte SDK > Packages > flyte.errors > `ActionNotFoundError`** | This error is raised when the user tries to access an action that does not exist. |
| **Flyte SDK > Packages > flyte.errors > `BaseRuntimeError`** | Base class for all Union runtime errors. |
| **Flyte SDK > Packages > flyte.errors > `CustomError`** | This error is raised when the user raises a custom error. |
| **Flyte SDK > Packages > flyte.errors > `DeploymentError`** | This error is raised when the deployment of a task fails, or some preconditions for deployment are not met. |
| **Flyte SDK > Packages > flyte.errors > `ImageBuildError`** | This error is raised when the image build fails. |
| **Flyte SDK > Packages > flyte.errors > `ImagePullBackOffError`** | This error is raised when the image cannot be pulled. |
| **Flyte SDK > Packages > flyte.errors > `InitializationError`** | This error is raised when the Union system is accessed before it has been initialized. |
| **Flyte SDK > Packages > flyte.errors > `InlineIOMaxBytesBreached`** | This error is raised when the inline IO max bytes limit is breached. |
| **Flyte SDK > Packages > flyte.errors > `InvalidImageNameError`** | This error is raised when the image name is invalid. |
| **Flyte SDK > Packages > flyte.errors > `LogsNotYetAvailableError`** | This error is raised when the logs are not yet available for a task. |
| **Flyte SDK > Packages > flyte.errors > `ModuleLoadError`** | This error is raised when the module cannot be loaded, either because it does not exist or because of a. |
| **Flyte SDK > Packages > flyte.errors > `NotInTaskContextError`** | This error is raised when the user tries to access the task context outside of a task. |
| **Flyte SDK > Packages > flyte.errors > `OOMError`** | This error is raised when the underlying task execution fails because of an out-of-memory error. |
| **Flyte SDK > Packages > flyte.errors > `OnlyAsyncIOSupportedError`** | This error is raised when the user tries to use sync IO in an async task. |
| **Flyte SDK > Packages > flyte.errors > `PrimaryContainerNotFoundError`** | This error is raised when the primary container is not found. |
| **Flyte SDK > Packages > flyte.errors > `ReferenceTaskError`** | This error is raised when the user tries to access a task that does not exist. |
| **Flyte SDK > Packages > flyte.errors > `RetriesExhaustedError`** | This error is raised when the underlying task execution fails after all retries have been exhausted. |
| **Flyte SDK > Packages > flyte.errors > `RunAbortedError`** | This error is raised when the run is aborted by the user. |
| **Flyte SDK > Packages > flyte.errors > `RuntimeDataValidationError`** | This error is raised when the user tries to access a resource that does not exist or is invalid. |
| **Flyte SDK > Packages > flyte.errors > `RuntimeSystemError`** | This error is raised when the underlying task execution fails because of a system error. |
| **Flyte SDK > Packages > flyte.errors > `RuntimeUnknownError`** | This error is raised when the underlying task execution fails because of an unknown error. |
| **Flyte SDK > Packages > flyte.errors > `RuntimeUserError`** | This error is raised when the underlying task execution fails because of an error in the user's code. |
| **Flyte SDK > Packages > flyte.errors > `SlowDownError`** | This error is raised when the user tries to access a resource that does not exist or is invalid. |
| **Flyte SDK > Packages > flyte.errors > `TaskInterruptedError`** | This error is raised when the underlying task execution is interrupted. |
| **Flyte SDK > Packages > flyte.errors > `TaskTimeoutError`** | This error is raised when the underlying task execution runs for longer than the specified timeout. |
| **Flyte SDK > Packages > flyte.errors > `UnionRpcError`** | This error is raised when communication with the Union server fails. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.errors > `silence_grpc_polling_error()`** | Suppress specific gRPC polling errors in the event loop. |
## Methods
#### silence_grpc_polling_error()
```python
def silence_grpc_polling_error(
loop,
context,
)
```
Suppress specific gRPC polling errors in the event loop.
| Parameter | Type | Description |
|-|-|-|
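| `loop` | | The event loop in which the error was reported. |
| `context` | | The exception context dictionary supplied by the event loop. |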
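
The `(loop, context)` signature matches the handler signature that asyncio's `loop.set_exception_handler()` expects. This suggests, but the page does not confirm, that the function is meant to be registered as an event-loop exception handler. The sketch below is a hypothetical handler built on that pattern, not the SDK's code. It assumes the noisy error shows up as a `BlockingIOError` whose message mentions `PollerCompletionQueue`. Any other error is passed on to asyncio's default handler, which logs it.

```python
import asyncio
from typing import Any, Dict


def silence_polling_errors(loop: asyncio.AbstractEventLoop, context: Dict[str, Any]) -> None:
    """Hypothetical asyncio exception handler that drops one class of noisy errors."""
    exc = context.get("exception")
    # Hypothetical filter: treat BlockingIOError raised from a gRPC poller callback as noise.
    if isinstance(exc, BlockingIOError) and "PollerCompletionQueue" in context.get("message", ""):
        return  # suppressed
    # Everything else is delegated to asyncio's default handler.
    loop.default_exception_handler(context)


# Usage: loop.set_exception_handler(silence_polling_errors)
```
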
## Subpages
- [ActionNotFoundError](ActionNotFoundError/)
- [BaseRuntimeError](BaseRuntimeError/)
- [CustomError](CustomError/)
- [DeploymentError](DeploymentError/)
- [ImageBuildError](ImageBuildError/)
- [ImagePullBackOffError](ImagePullBackOffError/)
- [InitializationError](InitializationError/)
- [InlineIOMaxBytesBreached](InlineIOMaxBytesBreached/)
- [InvalidImageNameError](InvalidImageNameError/)
- [LogsNotYetAvailableError](LogsNotYetAvailableError/)
- [ModuleLoadError](ModuleLoadError/)
- [NotInTaskContextError](NotInTaskContextError/)
- [OnlyAsyncIOSupportedError](OnlyAsyncIOSupportedError/)
- [OOMError](OOMError/)
- [PrimaryContainerNotFoundError](PrimaryContainerNotFoundError/)
- [ReferenceTaskError](ReferenceTaskError/)
- [RetriesExhaustedError](RetriesExhaustedError/)
- [RunAbortedError](RunAbortedError/)
- [RuntimeDataValidationError](RuntimeDataValidationError/)
- [RuntimeSystemError](RuntimeSystemError/)
- [RuntimeUnknownError](RuntimeUnknownError/)
- [RuntimeUserError](RuntimeUserError/)
- [SlowDownError](SlowDownError/)
- [TaskInterruptedError](TaskInterruptedError/)
- [TaskTimeoutError](TaskTimeoutError/)
- [UnionRpcError](UnionRpcError/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.extend ===
# flyte.extend
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > `AsyncFunctionTaskTemplate`** | A task template that wraps an asynchronous function. |
| **Flyte SDK > Packages > flyte.extend > `ImageBuildEngine`** | ImageBuildEngine contains a list of builders that can be used to build an ImageSpec. |
| **Flyte SDK > Packages > flyte.extend > `TaskTemplate`** | Task template is a template for a task that can be executed. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extend > `download_code_bundle()`** | Downloads the code bundle if it is not already downloaded. |
| **Flyte SDK > Packages > flyte.extend > `get_proto_resources()`** | Get the IDL representation of the main resources from the resources object. |
| **Flyte SDK > Packages > flyte.extend > `is_initialized()`** | Check if the system has been initialized. |
| **Flyte SDK > Packages > flyte.extend > `pod_spec_from_resources()`** | |
### Variables
| Property | Type | Description |
|-|-|-|
| `PRIMARY_CONTAINER_DEFAULT_NAME` | `str` | |
| `TaskPluginRegistry` | `_Registry` | |
## Methods
#### download_code_bundle()
```python
def download_code_bundle(
code_bundle: flyte.models.CodeBundle,
) -> flyte.models.CodeBundle
```
Downloads the code bundle if it is not already downloaded.
| Parameter | Type | Description |
|-|-|-|
| `code_bundle` | `flyte.models.CodeBundle` | The code bundle to download. Returns: The code bundle with the downloaded path. |
#### get_proto_resources()
```python
def get_proto_resources(
resources: flyte._resources.Resources | None,
) -> typing.Optional[flyteidl2.core.tasks_pb2.Resources]
```
Get the IDL representation of the main resources from the resources object.
| Parameter | Type | Description |
|-|-|-|
| `resources` | `flyte._resources.Resources \| None` | User-facing `Resources` object, potentially containing both requests and limits. Returns: The given resources as requests and limits. |
#### is_initialized()
```python
def is_initialized()
```
Check if the system has been initialized.
Returns: `True` if initialized, `False` otherwise.
#### pod_spec_from_resources()
```python
def pod_spec_from_resources(
primary_container_name: str,
requests: typing.Optional[flyte._resources.Resources],
limits: typing.Optional[flyte._resources.Resources],
k8s_gpu_resource_key: str,
) -> V1PodSpec
```
| Parameter | Type | Description |
|-|-|-|
| `primary_container_name` | `str` | |
| `requests` | `typing.Optional[flyte._resources.Resources]` | |
| `limits` | `typing.Optional[flyte._resources.Resources]` | |
| `k8s_gpu_resource_key` | `str` | |
## Subpages
- [AsyncFunctionTaskTemplate](AsyncFunctionTaskTemplate/)
- [ImageBuildEngine](ImageBuildEngine/)
- [TaskTemplate](TaskTemplate/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.extras ===
# flyte.extras
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.extras > `ContainerTask`** | This is an intermediate class that represents Flyte Tasks that run a container at execution time. |
## Subpages
- [ContainerTask](ContainerTask/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.git ===
# flyte.git
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.git > `GitStatus`** | A class representing the status of a git repository. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.git > `config_from_root()`** | Get the config file from the git root directory. |
## Methods
#### config_from_root()
```python
def config_from_root(
path: pathlib.Path | str,
) -> flyte.config.Config | None
```
Get the config file from the git root directory.
By default, the config file is expected to be in `.flyte/config.yaml` in the git root directory.
| Parameter | Type | Description |
|-|-|-|
| `path` | `pathlib.Path \| str` | Path to the config file, relative to the git root directory (default: `.flyte/config.yaml`). Returns: `Config` object if found, `None` otherwise. |
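
The sketch below approximates the lookup with the standard library only. It assumes the git root can be found by walking up the directory tree until a `.git` entry appears. The SDK may locate the root differently, for example by calling the `git` CLI, and it returns a parsed `Config` rather than a path. `find_git_root` and `config_path_from_root` are hypothetical helpers.

```python
from pathlib import Path
from typing import Optional, Union


def find_git_root(start: Path) -> Optional[Path]:
    """Walk up from `start` until a directory containing `.git` is found."""
    start = start.resolve()
    for directory in (start, *start.parents):
        if (directory / ".git").exists():
            return directory
    return None


def config_path_from_root(
    path: Union[Path, str] = ".flyte/config.yaml", start: Optional[Path] = None
) -> Optional[Path]:
    """Return <git_root>/<path> if that file exists, otherwise None."""
    root = find_git_root(start or Path.cwd())
    if root is None:
        return None
    candidate = root / path
    return candidate if candidate.is_file() else None
```
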
## Subpages
- **Flyte SDK > Packages > flyte.git > GitStatus**
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.git/gitstatus ===
# GitStatus
**Package:** `flyte.git`
A class representing the status of a git repository.
```python
class GitStatus(
is_valid: bool,
is_tree_clean: bool,
remote_url: str,
repo_dir: pathlib.Path,
commit_sha: str,
)
```
| Parameter | Type | Description |
|-|-|-|
| `is_valid` | `bool` | Whether git repository is valid |
| `is_tree_clean` | `bool` | Whether working tree is clean |
| `remote_url` | `str` | Remote URL in HTTPS format |
| `repo_dir` | `pathlib.Path` | Repository root directory |
| `commit_sha` | `str` | Current commit SHA |
## Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.git > GitStatus > `build_url()`** | Build a git URL for the given path. |
| **Flyte SDK > Packages > flyte.git > GitStatus > `from_current_repo()`** | Discover git information from the current repository. |
### build_url()
```python
def build_url(
path: pathlib.Path | str,
line_number: int,
) -> str
```
Build a git URL for the given path.
| Parameter | Type | Description |
|-|-|-|
| `path` | `pathlib.Path \| str` | Path to a file |
| `line_number` | `int` | Line number of the code file. Returns: Path relative to `repo_dir`. |
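
This page does not document the exact string that `build_url()` returns. The hypothetical sketch below assumes a GitHub-style permalink layout, `<remote>/blob/<commit_sha>/<relative path>#L<line>`, built from the `remote_url`, `commit_sha`, and `repo_dir` fields of `GitStatus`. Other hosting services use different URL layouts.

```python
from pathlib import Path
from typing import Union


def build_blob_url(
    remote_url: str, commit_sha: str, repo_dir: Path, path: Union[Path, str], line_number: int
) -> str:
    """Hypothetical GitHub-style permalink: <remote>/blob/<sha>/<relative path>#L<line>."""
    base = remote_url[: -len(".git")] if remote_url.endswith(".git") else remote_url
    # Make the file path relative to the repository root, using forward slashes.
    relative = Path(path).resolve().relative_to(repo_dir.resolve())
    return f"{base.rstrip('/')}/blob/{commit_sha}/{relative.as_posix()}#L{line_number}"
```
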
### from_current_repo()
```python
def from_current_repo()
```
Discover git information from the current repository.
If Git is not installed or `.git` does not exist, returns a `GitStatus` with `is_valid=False`.
Returns: A `GitStatus` instance with the discovered git information.
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.io ===
# flyte.io
## IO data types
This package contains additional data types beyond the primitive data types in Python to abstract the data flow
of large datasets in Union.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io > `DataFrame`** | This is the user-facing DataFrame class. |
| **Flyte SDK > Packages > flyte.io > `DataFrameDecoder`** | Helper class that provides a standard way to create an ABC using inheritance. |
| **Flyte SDK > Packages > flyte.io > `DataFrameEncoder`** | Helper class that provides a standard way to create an ABC using inheritance. |
| **Flyte SDK > Packages > flyte.io > `DataFrameTransformerEngine`** | Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. |
| **Flyte SDK > Packages > flyte.io > `Dir`** | A generic directory class representing a directory with files of a specified format. |
| **Flyte SDK > Packages > flyte.io > `File`** | A generic file class representing a file with a specified format. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.io > `lazy_import_dataframe_handler()`** | |
### Variables
| Property | Type | Description |
|-|-|-|
| `PARQUET` | `str` | |
## Methods
#### lazy_import_dataframe_handler()
```python
def lazy_import_dataframe_handler()
```
## Subpages
- [DataFrame](DataFrame/)
- [DataFrameDecoder](DataFrameDecoder/)
- [DataFrameEncoder](DataFrameEncoder/)
- [DataFrameTransformerEngine](DataFrameTransformerEngine/)
- [Dir](Dir/)
- [File](File/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.models ===
# flyte.models
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > `ActionID`** | A class representing the ID of an Action, nested within a Run. |
| **Flyte SDK > Packages > flyte.models > `Checkpoints`** | A class representing the checkpoints for a task. |
| **Flyte SDK > Packages > flyte.models > `CodeBundle`** | A class representing a code bundle for a task. |
| **Flyte SDK > Packages > flyte.models > `GroupData`** | |
| **Flyte SDK > Packages > flyte.models > `NativeInterface`** | A class representing the native interface for a task. |
| **Flyte SDK > Packages > flyte.models > `PathRewrite`** | Configuration for rewriting paths during input loading. |
| **Flyte SDK > Packages > flyte.models > `RawDataPath`** | A class representing the raw data path for a task. |
| **Flyte SDK > Packages > flyte.models > `SerializationContext`** | This object holds serialization-time contextual information that can be used when serializing the task and. |
| **Flyte SDK > Packages > flyte.models > `TaskContext`** | A context class to hold the current task execution's context. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.models > `generate_random_name()`** | Generate a random name for the task. |
### Variables
| Property | Type | Description |
|-|-|-|
| `MAX_INLINE_IO_BYTES` | `int` | |
| `TYPE_CHECKING` | `bool` | |
## Methods
#### generate_random_name()
```python
def generate_random_name()
```
Generate a random name for the task. This is used to create unique names for tasks.
Currently, the generated names are GUIDs.
## Subpages
- [ActionID](ActionID/)
- [Checkpoints](Checkpoints/)
- [CodeBundle](CodeBundle/)
- [GroupData](GroupData/)
- [NativeInterface](NativeInterface/)
- [PathRewrite](PathRewrite/)
- [RawDataPath](RawDataPath/)
- [SerializationContext](SerializationContext/)
- [TaskContext](TaskContext/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.remote ===
# flyte.remote
Remote Entities that are accessible from the Union Server once deployed or created.
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > `Action`** | A class representing an action. |
| **Flyte SDK > Packages > flyte.remote > `ActionDetails`** | A class representing an action. |
| **Flyte SDK > Packages > flyte.remote > `ActionInputs`** | A class representing the inputs of an action. |
| **Flyte SDK > Packages > flyte.remote > `ActionOutputs`** | A class representing the outputs of an action. |
| **Flyte SDK > Packages > flyte.remote > `App`** | A mixin class that provides a method to convert an object to a JSON-serializable dictionary. |
| **Flyte SDK > Packages > flyte.remote > `Project`** | A class representing a project in the Union API. |
| **Flyte SDK > Packages > flyte.remote > `Run`** | A class representing a run of a task. |
| **Flyte SDK > Packages > flyte.remote > `RunDetails`** | A class representing a run of a task. |
| **Flyte SDK > Packages > flyte.remote > `Secret`** | |
| **Flyte SDK > Packages > flyte.remote > `Task`** | |
| **Flyte SDK > Packages > flyte.remote > `TaskDetails`** | |
| **Flyte SDK > Packages > flyte.remote > `Trigger`** | |
| **Flyte SDK > Packages > flyte.remote > `User`** | |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.remote > `create_channel()`** | Creates a new gRPC channel with appropriate authentication interceptors. |
| **Flyte SDK > Packages > flyte.remote > `upload_dir()`** | Uploads a directory to a remote location and returns the remote URI. |
| **Flyte SDK > Packages > flyte.remote > `upload_file()`** | Uploads a file to a remote location and returns the remote URI. |
## Methods
#### create_channel()
```python
def create_channel(
endpoint: str | None,
api_key: str | None,
insecure: typing.Optional[bool],
insecure_skip_verify: typing.Optional[bool],
ca_cert_file_path: typing.Optional[str],
ssl_credentials: typing.Optional[ssl_channel_credentials],
grpc_options: typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]],
compression: typing.Optional[grpc.Compression],
http_session: httpx.AsyncClient | None,
proxy_command: typing.Optional[typing.List[str]],
kwargs,
) -> grpc.aio._base_channel.Channel
```
Creates a new gRPC channel with appropriate authentication interceptors.
This function creates either a secure or insecure gRPC channel based on the provided parameters,
and adds authentication interceptors to the channel. If SSL credentials are not provided,
they are created based on the `insecure_skip_verify` and `ca_cert_file_path` parameters.
The function is async because it may need to read certificate files asynchronously
and create authentication interceptors that perform async operations.
| Parameter | Type | Description |
|-|-|-|
| `endpoint` | `str \| None` | The endpoint URL for the gRPC channel |
| `api_key` | `str \| None` | API key for authentication; if provided, it will be used to detect the endpoint and credentials. |
| `insecure` | `typing.Optional[bool]` | Whether to use an insecure channel (no SSL) |
| `insecure_skip_verify` | `typing.Optional[bool]` | Whether to skip SSL certificate verification |
| `ca_cert_file_path` | `typing.Optional[str]` | Path to CA certificate file for SSL verification |
| `ssl_credentials` | `typing.Optional[ssl_channel_credentials]` | Pre-configured SSL credentials for the channel |
| `grpc_options` | `typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]]` | Additional gRPC channel options |
| `compression` | `typing.Optional[grpc.Compression]` | Compression method for the channel |
| `http_session` | `httpx.AsyncClient \| None` | Pre-configured HTTP session to use for requests |
| `proxy_command` | `typing.Optional[typing.List[str]]` | List of strings for proxy command configuration |
| `kwargs` | `**kwargs` | Additional arguments passed to various functions, grouped in the list below. Returns: `grpc.aio.Channel` with authentication interceptors configured. |

The `kwargs` accepted by `create_channel()` are:
- For `grpc.aio.insecure_channel` / `grpc.aio.secure_channel`:
  - `root_certificates`: Root certificates for SSL credentials
  - `private_key`: Private key for SSL credentials
  - `certificate_chain`: Certificate chain for SSL credentials
  - `options`: gRPC channel options
  - `compression`: gRPC compression method
- For proxy configuration:
  - `proxy_env`: Dict of environment variables for the proxy
  - `proxy_timeout`: Timeout for the proxy connection
- For authentication interceptors (passed to `create_auth_interceptors` and `create_proxy_auth_interceptors`):
  - `auth_type`: The authentication type to use (`"Pkce"`, `"ClientSecret"`, `"ExternalCommand"`, `"DeviceFlow"`)
  - `command`: Command to execute for `ExternalCommand` authentication
  - `client_id`: Client ID for `ClientSecret` authentication
  - `client_secret`: Client secret for `ClientSecret` authentication
  - `client_credentials_secret`: Client secret for `ClientSecret` authentication (alias)
  - `scopes`: List of scopes to request during authentication
  - `audience`: Audience for the token
  - `http_proxy_url`: HTTP proxy URL
  - `verify`: Whether to verify SSL certificates
  - `ca_cert_path`: Optional path to a CA certificate file
  - `header_key`: Header key to use for authentication
  - `redirect_uri`: OAuth2 redirect URI for PKCE authentication
  - `add_request_auth_code_params_to_request_access_token_params`: Whether to add the auth code params to the token request
  - `request_auth_code_params`: Parameters to add to the login URI opened in the browser
  - `request_access_token_params`: Parameters to add when exchanging the auth code for an access token
  - `refresh_access_token_params`: Parameters to add when refreshing the access token
#### upload_dir()
```python
def upload_dir(
dir_path: pathlib.Path,
verify: bool,
) -> str
```
Uploads a directory to a remote location and returns the remote URI.
| Parameter | Type | Description |
|-|-|-|
| `dir_path` | `pathlib.Path` | The directory path to upload. |
| `verify` | `bool` | Whether to verify the certificate for HTTPS requests. Returns: The remote URI of the uploaded directory. |
#### upload_file()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.:
> `result = await upload_file.aio()`.
```python
def upload_file(
fp: pathlib.Path,
verify: bool,
) -> typing.Tuple[str, str]
```
Uploads a file to a remote location and returns the remote URI.
| Parameter | Type | Description |
|-|-|-|
| `fp` | `pathlib.Path` | The file path to upload. |
| `verify` | `bool` | Whether to verify the certificate for HTTPS requests. Returns: A tuple containing the MD5 digest and the remote URI. |
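
The sync-or-`.aio()` calling convention described in the note above can be modeled with a small wrapper. The sketch below is a hypothetical illustration of the pattern, not how the Flyte SDK implements it. `SyncOrAsync` and `fake_upload` are invented names, and the blocking path uses `asyncio.run()`, so it cannot be called from inside a running event loop.

```python
import asyncio
import functools
from typing import Any, Callable, Coroutine, Generic, TypeVar

T = TypeVar("T")


class SyncOrAsync(Generic[T]):
    """Hypothetical wrapper: calling it blocks, while `.aio()` returns an awaitable."""

    def __init__(self, coro_fn: Callable[..., Coroutine[Any, Any, T]]) -> None:
        self._coro_fn = coro_fn
        functools.update_wrapper(self, coro_fn)

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        # Blocking path: run the coroutine to completion on a fresh event loop.
        return asyncio.run(self._coro_fn(*args, **kwargs))

    def aio(self, *args: Any, **kwargs: Any) -> Coroutine[Any, Any, T]:
        # Async path: return the coroutine so the caller can await it.
        return self._coro_fn(*args, **kwargs)


@SyncOrAsync
async def fake_upload(name: str) -> str:
    await asyncio.sleep(0)  # stand-in for real network I/O
    return f"remote://bucket/{name}"


print(fake_upload("a.txt"))  # remote://bucket/a.txt


async def main() -> str:
    return await fake_upload.aio("b.txt")


print(asyncio.run(main()))  # remote://bucket/b.txt
```
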
## Subpages
- [Action](Action/)
- [ActionDetails](ActionDetails/)
- [ActionInputs](ActionInputs/)
- [ActionOutputs](ActionOutputs/)
- [App](App/)
- [Project](Project/)
- [Run](Run/)
- [RunDetails](RunDetails/)
- [Secret](Secret/)
- [Task](Task/)
- [TaskDetails](TaskDetails/)
- [Trigger](Trigger/)
- [User](User/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.report ===
# flyte.report
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.report > `Report`** | |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.report > `current_report()`** | Get the current report. |
| **Flyte SDK > Packages > flyte.report > `flush()`** | Flush the report. |
| **Flyte SDK > Packages > flyte.report > `get_tab()`** | Get a tab by name. |
| **Flyte SDK > Packages > flyte.report > `log()`** | Log content to the main tab. |
| **Flyte SDK > Packages > flyte.report > `replace()`** | Replace the content of the main tab. |
## Methods
#### current_report()
```python
def current_report()
```
Get the current report. This is a dummy report if not in a task context.
Returns: The current report.
#### flush()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.:
> `result = await flush.aio()`.
```python
def flush()
```
Flush the report.
#### get_tab()
```python
def get_tab(
name: str,
create_if_missing: bool,
) -> flyte.report._report.Tab
```
Get a tab by name. If the tab does not exist and `create_if_missing` is `True`, create it.
| Parameter | Type | Description |
|-|-|-|
| `name` | `str` | The name of the tab. |
| `create_if_missing` | `bool` | Whether to create the tab if it does not exist. Returns: The tab. |
#### log()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.:
> `result = await log.aio()`.
```python
def log(
content: str,
do_flush: bool,
)
```
Log content to the main tab. The content should be a valid HTML string, but not a complete HTML document,
as it will be inserted into a div.
| Parameter | Type | Description |
|-|-|-|
| `content` | `str` | The content to log. |
| `do_flush` | `bool` | Whether to flush the report after logging. |
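
Because `log()` inserts its content into a `div`, the content should be an HTML fragment rather than a complete document. The sketch below builds such a fragment with the standard library and escapes user-supplied text so that it cannot break the surrounding markup. `metrics_fragment` is a hypothetical helper. The final commented-out call uses the documented `log(content, do_flush)` signature and assumes it runs inside a task with reporting enabled.

```python
import html
from typing import Iterable, Tuple


def metrics_fragment(rows: Iterable[Tuple[str, float]]) -> str:
    """Build an HTML fragment (no <html> or <body>) that can be inserted into a <div>."""
    body = "".join(
        f"<tr><td>{html.escape(name)}</td><td>{value:.3f}</td></tr>" for name, value in rows
    )
    return f"<table><tr><th>metric</th><th>value</th></tr>{body}</table>"


fragment = metrics_fragment([("accuracy", 0.91234), ("loss <val>", 0.2)])
print(fragment)
# Inside a task (assumption: reporting is enabled for that task):
# flyte.report.log(fragment, do_flush=True)
```
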
#### replace()
> [!NOTE] This method can be called either synchronously or asynchronously.
> Default invocation is sync and will block.
> To call it asynchronously, use the function `.aio()` on the method name itself, e.g.:
> `result = await replace.aio()`.
```python
def replace(
content: str,
do_flush: bool,
)
```
Replace the content of the main tab.
Returns: The report.
| Parameter | Type | Description |
|-|-|-|
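| `content` | `str` | The content that replaces the current content of the main tab. |
| `do_flush` | `bool` | Whether to flush the report after replacing the content. |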
## Subpages
- [Report](Report/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.storage ===
# flyte.storage
## Directory
### Classes
| Class | Description |
|-|-|
| **Flyte SDK > Packages > flyte.storage > `ABFS`** | Any Azure Blob Storage specific configuration. |
| **Flyte SDK > Packages > flyte.storage > `GCS`** | Any GCS specific configuration. |
| **Flyte SDK > Packages > flyte.storage > `S3`** | S3 specific configuration. |
| **Flyte SDK > Packages > flyte.storage > `Storage`** | Data storage configuration that applies across any provider. |
### Methods
| Method | Description |
|-|-|
| **Flyte SDK > Packages > flyte.storage > `exists()`** | Check if a path exists. |
| **Flyte SDK > Packages > flyte.storage > `exists_sync()`** | |
| **Flyte SDK > Packages > flyte.storage > `get()`** | |
| **Flyte SDK > Packages > flyte.storage > `get_configured_fsspec_kwargs()`** | |
| **Flyte SDK > Packages > flyte.storage > `get_random_local_directory()`** | Returns a random directory. |
| **Flyte SDK > Packages > flyte.storage > `get_random_local_path()`** | Use `file_path_or_file_name` when you want a random directory but want to preserve the leaf file name. |
| **Flyte SDK > Packages > flyte.storage > `get_stream()`** | Get a stream of data from a remote location. |
| **Flyte SDK > Packages > flyte.storage > `get_underlying_filesystem()`** | |
| **Flyte SDK > Packages > flyte.storage > `is_remote()`** | |
| **Flyte SDK > Packages > flyte.storage > `join()`** | Join multiple paths together. |
| **Flyte SDK > Packages > flyte.storage > `open()`** | Asynchronously open a file and return an async context manager. |
| **Flyte SDK > Packages > flyte.storage > `put()`** | |
| **Flyte SDK > Packages > flyte.storage > `put_stream()`** | Put a stream of data to a remote location. |
## Methods
#### exists()
```python
def exists(
path: str,
kwargs,
) -> bool
```
Check if a path exists. Returns `True` if the path exists, `False` otherwise.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | Path to be checked. |
| `kwargs` | `**kwargs` | Additional arguments to be passed to the underlying filesystem. |
#### exists_sync()
```python
def exists_sync(
path: str,
kwargs,
) -> bool
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | |
| `kwargs` | `**kwargs` | |
#### get()
```python
def get(
from_path: str,
to_path: Optional[str | pathlib.Path],
recursive: bool,
kwargs,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `from_path` | `str` | |
| `to_path` | `Optional[str \| pathlib.Path]` | |
| `recursive` | `bool` | |
| `kwargs` | `**kwargs` | |
#### get_configured_fsspec_kwargs()
```python
def get_configured_fsspec_kwargs(
protocol: typing.Optional[str],
anonymous: bool,
) -> typing.Dict[str, typing.Any]
```
| Parameter | Type | Description |
|-|-|-|
| `protocol` | `typing.Optional[str]` | |
| `anonymous` | `bool` | |
#### get_random_local_directory()
```python
def get_random_local_directory() -> pathlib.Path
```
Return a random local directory as a `pathlib.Path`.
#### get_random_local_path()
```python
def get_random_local_path(
file_path_or_file_name: pathlib.Path | str | None,
) -> pathlib.Path
```
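Return a random local path. Pass `file_path_or_file_name` when you want a random directory but want to preserve the leaf file name.
| Parameter | Type | Description |
|-|-|-|
| `file_path_or_file_name` | `pathlib.Path \| str \| None` | Optional file path or file name whose leaf name is preserved in the returned path. |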
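The following standard-library sketch makes the documented behavior concrete. Each call picks a fresh random directory, and when a file path or name is passed, only its leaf name is kept. This is an illustration under assumptions, not the library's implementation. The function name `random_local_path`, the use of the system temp directory, and the UUID naming are all choices made here. In real code, call `get_random_local_path()` itself.

```python
from __future__ import annotations

import tempfile
import uuid
from pathlib import Path


def random_local_path(file_path_or_file_name: Path | str | None = None) -> Path:
    """Hypothetical sketch: a path inside a fresh random directory.

    If a file path or name is given, only its leaf name is kept.
    This sketch only computes the path; it does not create anything on disk.
    """
    random_dir = Path(tempfile.gettempdir()) / uuid.uuid4().hex
    if file_path_or_file_name is None:
        return random_dir / uuid.uuid4().hex
    return random_dir / Path(file_path_or_file_name).name
```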
#### get_stream()
```python
def get_stream(
path: str,
chunk_size,
kwargs,
) -> AsyncGenerator[bytes, None]
```
Get a stream of data from a remote location.
This is useful for downloading streaming data from a remote location. Returns an async iterator that yields chunks of bytes.
Example usage:
```python
import asyncio
import flyte.storage as storage

async def main() -> None:
    async for chunk in storage.get_stream(path="s3://my_bucket/my_file.txt"):
        print(len(chunk))  # replace with your own processing of each chunk

asyncio.run(main())
```
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | Path to the remote location the data will be downloaded from. |
| `chunk_size` | | Size of each chunk to be read from the file. |
| `kwargs` | `**kwargs` | Additional arguments to be passed to the underlying filesystem. |
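Because the data arrives as an async iterator of `bytes`, you can process a large remote file incrementally without holding it all in memory. The sketch below computes a SHA-256 digest chunk by chunk. `sha256_of_stream` and `fake_stream` are hypothetical names. `fake_stream` stands in for `get_stream()` so the example runs without a storage backend.

```python
import asyncio
import hashlib
from typing import AsyncIterable, AsyncIterator


async def sha256_of_stream(chunks: AsyncIterable[bytes]) -> str:
    """Hash an async byte stream incrementally, one chunk at a time."""
    digest = hashlib.sha256()
    async for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()


async def fake_stream(data: bytes, chunk_size: int) -> AsyncIterator[bytes]:
    """Stand-in for a remote stream so the sketch runs without any storage backend."""
    for i in range(0, len(data), chunk_size):
        await asyncio.sleep(0)  # yield control, as real network I/O would
        yield data[i : i + chunk_size]


digest = asyncio.run(sha256_of_stream(fake_stream(b"hello world", chunk_size=4)))
```

With a real remote file, you would instead pass the iterator returned by `get_stream()`, e.g. `await sha256_of_stream(storage.get_stream(path="s3://my_bucket/my_file.txt"))` inside a coroutine.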
#### get_underlying_filesystem()
```python
def get_underlying_filesystem(
protocol: typing.Optional[str],
anonymous: bool,
path: typing.Optional[str],
kwargs,
) -> fsspec.AbstractFileSystem
```
| Parameter | Type | Description |
|-|-|-|
| `protocol` | `typing.Optional[str]` | |
| `anonymous` | `bool` | |
| `path` | `typing.Optional[str]` | |
| `kwargs` | `**kwargs` | |
#### is_remote()
```python
def is_remote(
path: typing.Union[pathlib.Path | str],
) -> bool
```
Check whether `path` points to a remote location rather than the local file system.
| Parameter | Type | Description |
|-|-|-|
| `path` | `typing.Union[pathlib.Path \| str]` | |
#### join()
```python
def join(
paths: str,
) -> str
```
Join multiple paths together. This is a wrapper around `os.path.join`, so it does not yet handle fsspec roots (such as `s3://` prefixes) specially.
| Parameter | Type | Description |
|-|-|-|
| `paths` | `str` | Paths to be joined. |
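`join()` is documented as a wrapper around `os.path.join`, so it inherits that function's semantics. The standard-library snippet below does not call `flyte.storage`. It uses `posixpath`, which is what `os.path` resolves to on Linux and macOS, to show two consequences worth knowing when joining remote paths. On Windows, `os.path` joins with backslashes instead.

```python
import posixpath

# On POSIX systems os.path is posixpath, so simple URI-style joins look right:
joined = posixpath.join("s3://my_bucket", "data", "file.txt")
print(joined)  # s3://my_bucket/data/file.txt

# ...but a component that starts with "/" discards everything before it,
# including the bucket prefix:
absolute = posixpath.join("s3://my_bucket", "/data/file.txt")
print(absolute)  # /data/file.txt

# A safer pattern when building remote keys: strip leading slashes first.
safe = posixpath.join("s3://my_bucket", "/data/file.txt".lstrip("/"))
print(safe)  # s3://my_bucket/data/file.txt
```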
#### open()
```python
def open(
path: str,
mode: str,
kwargs,
) -> AsyncReadableFile | AsyncWritableFile
```
Asynchronously open a file and return an async context manager.
This function checks whether the underlying filesystem supports the obstore bypass.
If it does, it uses obstore to open the file. Otherwise, it falls back to
the standard `_open` function, which uses `AsyncFileSystem`.
It raises `NotImplementedError` if neither obstore nor `AsyncFileSystem` is supported.
| Parameter | Type | Description |
|-|-|-|
| `path` | `str` | |
| `mode` | `str` | |
| `kwargs` | `**kwargs` | |
#### put()
```python
def put(
from_path: str,
to_path: Optional[str],
recursive: bool,
kwargs,
) -> str
```
| Parameter | Type | Description |
|-|-|-|
| `from_path` | `str` | |
| `to_path` | `Optional[str]` | |
| `recursive` | `bool` | |
| `kwargs` | `**kwargs` | |
#### put_stream()
```python
def put_stream(
data_iterable: typing.AsyncIterable[bytes] | bytes,
name: str | None,
to_path: str | None,
kwargs,
) -> str
```
Put a stream of data to a remote location. This is useful for streaming data to a remote location.
Returns the path to the remote location where the data was stored.
Example usage:
```python
import asyncio
import flyte.storage as storage

async def main() -> None:
    # Store the data under a given file name
    path = await storage.put_stream(b"hello", name="my_file.txt")
    # OR store it at an explicit remote path
    path = await storage.put_stream(b"hello", to_path="s3://my_bucket/my_file.txt")

asyncio.run(main())
```
| Parameter | Type | Description |
|-|-|-|
| `data_iterable` | `typing.AsyncIterable[bytes] \| bytes` | Async iterable of bytes, or a single `bytes` object, to be streamed. |
| `name` | `str \| None` | Name of the file to be created. If not provided, a random name will be generated. |
| `to_path` | `str \| None` | Path to the remote location where the data will be stored. |
| `kwargs` | `**kwargs` | Additional arguments to be passed to the underlying filesystem. |
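`data_iterable` accepts an async iterable of `bytes`. For large local files, an async generator avoids loading the whole file into memory. The following is a minimal sketch. The generator name `iter_file_chunks` and its 1 MiB default chunk size are assumptions. The upload call is shown only as a comment because it needs a configured storage backend. The demo at the end writes a small temporary file and collects its chunks.

```python
import asyncio
import tempfile
from pathlib import Path
from typing import AsyncIterator


async def iter_file_chunks(path: Path, chunk_size: int = 1024 * 1024) -> AsyncIterator[bytes]:
    """Yield a local file's bytes in fixed-size chunks (the last one may be shorter)."""
    with open(path, "rb") as f:
        while True:
            # Read in a worker thread so disk I/O does not block the event loop.
            chunk = await asyncio.to_thread(f.read, chunk_size)
            if not chunk:
                break
            yield chunk


async def _collect(path: Path, chunk_size: int) -> list[bytes]:
    return [chunk async for chunk in iter_file_chunks(path, chunk_size)]


with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "sample.bin"
    sample.write_bytes(b"abcdefghij")
    demo_chunks = asyncio.run(_collect(sample, chunk_size=4))

# Hypothetical usage inside a coroutine:
# remote = await storage.put_stream(iter_file_chunks(Path("big.bin")), to_path="s3://my_bucket/big.bin")
```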
## Subpages
- [ABFS](ABFS/)
- [GCS](GCS/)
- [S3](S3/)
- [Storage](Storage/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.syncify ===
# flyte.syncify
# Syncify Module
This module provides the `syncify` decorator and the `Syncify` class.
The decorator can be used to convert asynchronous functions or methods into synchronous ones.
This is useful for integrating async code into synchronous contexts.
Every asynchronous function or method wrapped with `syncify` can be called synchronously using the
parenthesis `()` operator, or asynchronously using the `.aio()` method.
Example:
```python
import asyncio

from flyte.syncify import syncify

@syncify
async def async_function(x: str) -> str:
    return f"Hello, Async World {x}!"

# Now you can call it synchronously (no .aio() needed for sync calls)
result = async_function("Async World")
print(result)
# Output: Hello, Async World Async World!

# Or call it asynchronously with .aio()
async def main():
    result = await async_function.aio("World")
    print(result)
    # Output: Hello, Async World World!

asyncio.run(main())
```
## Creating a Syncify Instance
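```python
from flyte.syncify import Syncify
syncer = Syncify("my_syncer")
# Now you can use `syncer` to decorate your async functions or methods
@syncer
async def fetch_greeting(name: str) -> str:
    return f"Hello, {name}!"
print(fetch_greeting("Flyte"))  # called synchronously
```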
## How does it work?
The Syncify class wraps asynchronous functions, classmethods, instance methods, and static methods to
provide a synchronous interface. The wrapped methods are always executed in the context of a background loop,
whether they are called synchronously or asynchronously. This matters because some async libraries capture the
event loop they are first used on; `grpc.aio` is one example.
In such cases, the Syncify class ensures that the async function always runs on the same background loop.
To use it correctly with `grpc.aio`, wrap every `grpc.aio` channel creation and client invocation
with the same `Syncify` instance. This ensures that all of this async code runs in the same event loop context.
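The background-loop technique described above can be sketched with the standard library alone. This is not Flyte's implementation. The class name `BackgroundLoopSyncer` is hypothetical, and the sketch leaves out what a real implementation needs: shutdown, method and classmethod support, and error handling. It shows only the core idea. One event loop runs forever in a daemon thread. Sync calls block on `asyncio.run_coroutine_threadsafe(...).result()`, and `.aio()` calls await the same future through `asyncio.wrap_future`, so both paths execute on the same loop.

```python
import asyncio
import functools
import threading
from typing import Any, Callable, Coroutine, TypeVar

T = TypeVar("T")


class BackgroundLoopSyncer:
    """Hypothetical illustration: run every wrapped coroutine on one background loop."""

    def __init__(self, name: str) -> None:
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, name=name, daemon=True)
        self._thread.start()

    def __call__(self, fn: Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
        @functools.wraps(fn)
        def sync_wrapper(*args: Any, **kwargs: Any) -> T:
            # Submit the coroutine to the background loop and block until it finishes.
            fut = asyncio.run_coroutine_threadsafe(fn(*args, **kwargs), self._loop)
            return fut.result()

        async def aio(*args: Any, **kwargs: Any) -> T:
            # Still run on the background loop, but await instead of blocking the caller's loop.
            fut = asyncio.run_coroutine_threadsafe(fn(*args, **kwargs), self._loop)
            return await asyncio.wrap_future(fut)

        sync_wrapper.aio = aio  # type: ignore[attr-defined]
        return sync_wrapper


syncer = BackgroundLoopSyncer("demo-loop")


@syncer
async def greet(name: str) -> str:
    await asyncio.sleep(0)
    return f"Hello, {name}!"


@syncer
async def current_loop_id() -> int:
    # Reports which event loop actually executed the coroutine.
    return id(asyncio.get_running_loop())


print(greet("sync caller"))  # Hello, sync caller!
```

The `current_loop_id` helper returns the same value whether it is called directly or through `.aio()`. This is the property that keeps libraries such as `grpc.aio`, which capture their event loop, working from both sync and async callers.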
## Directory
### Classes
| Class | Description |
|-|-|
| `Syncify` | A decorator to convert asynchronous functions or methods into synchronous ones. |
## Subpages
- [Syncify](Syncify/)
=== PAGE: https://www.union.ai/docs/v2/byoc/api-reference/flyte-sdk/packages/flyte.types ===
# flyte.types
# Flyte Type System
The Flyte type system provides a way to define, transform, and manipulate types in Flyte workflows.
Since the data flowing through Flyte often has to cross process, container, and language boundaries, the type system
is designed to be serializable to a universal format that can be understood across different environments. This
universal format is based on Protocol Buffers. The types are called LiteralTypes, and the runtime
representation of data is called Literals.
The type system includes:
- **TypeEngine**: The core engine that manages type transformations and serialization. This is the main entry point
for all the internal type transformations and serialization logic.
- **TypeTransformer**: A class that defines how to transform one type to another. This is extensible,
allowing users to define custom types and transformations.
- **Renderable**: An interface for types that can be rendered as HTML and written to a `flyte.report`.
It is always possible to bypass the type system and use the `FlytePickle` type to serialize any Python object
into a pickle format. The pickle format is not human-readable, but pickled objects can be passed between Flyte tasks that are
written in Python. Pickled objects cannot be represented in the UI and may be inefficient for large datasets.
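The following toy sketch makes the split between TypeEngine and TypeTransformer concrete. A registry maps each Python type to a transformer that holds a pair of conversion functions. Everything here is hypothetical (`MiniTypeEngine`, `MiniTransformer`). It uses JSON where Flyte uses Protocol Buffer Literals, and it is not Flyte's API.

```python
import json
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class MiniTransformer:
    """Hypothetical: knows how to convert one Python type to and from a literal."""

    python_type: type
    to_literal: Callable[[Any], Any]  # Python value -> JSON-serializable value
    to_python: Callable[[Any], Any]   # JSON-serializable value -> Python value


class MiniTypeEngine:
    """Hypothetical registry that dispatches on a value's exact Python type."""

    def __init__(self) -> None:
        self._transformers: Dict[type, MiniTransformer] = {}

    def register(self, transformer: MiniTransformer) -> None:
        self._transformers[transformer.python_type] = transformer

    def _get(self, python_type: type) -> MiniTransformer:
        try:
            return self._transformers[python_type]
        except KeyError:
            raise TypeError(f"No transformer registered for {python_type.__name__}") from None

    def to_literal(self, value: Any) -> str:
        t = self._get(type(value))
        return json.dumps({"type": t.python_type.__name__, "value": t.to_literal(value)})

    def to_python(self, literal: str, python_type: type) -> Any:
        payload = json.loads(literal)
        return self._get(python_type).to_python(payload["value"])


engine = MiniTypeEngine()
engine.register(MiniTransformer(complex, lambda z: [z.real, z.imag], lambda v: complex(*v)))
literal = engine.to_literal(3 + 4j)
```

An unregistered type raises `TypeError`. In this toy, that is the point where a fallback such as pickling would come in.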
## Directory
### Classes
| Class | Description |
|-|-|
| `FlytePickle` | This type is only used by flytekit internally. |
| `TypeEngine` | Core Extensible TypeEngine of Flytekit. |
| `TypeTransformer` | Base transformer type that should be implemented for every python native type that can be handled by flytekit. |
### Protocols
| Protocol | Description |
|-|-|
| `Renderable` | Interface for types that can be rendered as HTML, e.g. in a `flyte.report`. |
### Errors
| Exception | Description |
|-|-|
| `TypeTransformerFailedError` | Raised when a type transformer fails to convert a value. |
### Methods
| Method | Description |
|-|-|
| `guess_interface()` | Returns the interface of the task with guessed types, as types may not be present in current env. |
| `literal_string_repr()` | This method is used to convert a literal map to a string representation. |
## Methods
#### guess_interface()
```python
def guess_interface(
interface: flyteidl2.core.interface_pb2.TypedInterface,
default_inputs: typing.Optional[typing.Iterable[flyteidl2.task.common_pb2.NamedParameter]],
) -> flyte.models.NativeInterface
```
Returns the interface of the task with guessed types, as types may not be present in current env.
| Parameter | Type | Description |
|-|-|-|
| `interface` | `flyteidl2.core.interface_pb2.TypedInterface` | |
| `default_inputs` | `typing.Optional[typing.Iterable[flyteidl2.task.common_pb2.NamedParameter]]` | |
#### literal_string_repr()
```python
def literal_string_repr(
lm: typing.Union[flyteidl2.core.literals_pb2.Literal, flyteidl2.task.common_pb2.NamedLiteral, flyteidl2.task.common_pb2.Inputs, flyteidl2.task.common_pb2.Outputs, flyteidl2.core.literals_pb2.LiteralMap, typing.Dict[str, flyteidl2.core.literals_pb2.Literal]],
) -> typing.Dict[str, typing.Any]
```
This method is used to convert a literal map to a string representation.
| Parameter | Type | Description |
|-|-|-|
| `lm` | `typing.Union[flyteidl2.core.literals_pb2.Literal, flyteidl2.task.common_pb2.NamedLiteral, flyteidl2.task.common_pb2.Inputs, flyteidl2.task.common_pb2.Outputs, flyteidl2.core.literals_pb2.LiteralMap, typing.Dict[str, flyteidl2.core.literals_pb2.Literal]]` | |
## Subpages
- [FlytePickle](FlytePickle/)
- [Renderable](Renderable/)
- [TypeEngine](TypeEngine/)
- [TypeTransformer](TypeTransformer/)
- [TypeTransformerFailedError](TypeTransformerFailedError/)
=== PAGE: https://www.union.ai/docs/v2/byoc/community ===
# Community
Union.ai is a commercial product built on top of the open source Flyte project.
Since the success of Flyte is essential to the success of Union.ai,
the company is dedicated to building and expanding the Flyte open source project and community.
For information on how to get involved and how to keep in touch, see the [Flyte variant of this page](/docs/v2/flyte//community).
## Contributing to documentation
Union AI maintains and hosts both Flyte and Union documentation at [www.union.ai/docs](/docs/v2/root/).
The two sets of documentation are deeply integrated, as the Union product is built on top of Flyte.
To better maintain both sets of docs, they are hosted in the same repository (but rendered so that you can choose to view one or the other).
Both the Flyte and Union documentation are open source.
Flyte community members and Union customers are both welcome to contribute to the documentation.
If you are interested, see [Contributing documentation and examples](./contributing-docs/_index).
## Subpages
- **Contributing docs and examples**
=== PAGE: https://www.union.ai/docs/v2/byoc/community/contributing-docs ===
# Contributing docs and examples
We welcome contributions to the docs and examples for both Flyte and Union.
This section will explain how the docs site works, how to author and build it locally, and how to publish your changes.
## The combined Flyte and Union docs site
As the primary maintainer and contributor of the open-source Flyte project, Union AI is responsible for hosting the Flyte documentation.
Additionally, Union AI is also the company behind the commercial Union.ai product, which is based on Flyte.
Since Flyte and Union.ai share a lot of common functionality, much of the documentation content is common between the two.
However, there are some significant differences not only between Flyte and Union.ai but also among the different Union.ai product offerings (Serverless, BYOC, and Self-managed).
To effectively and efficiently maintain the documentation for all of these variants, we employ a single-source-of-truth approach where:
* All content is stored in a single GitHub repository, [`unionai/unionai-docs`](https://github.com/unionai/unionai-docs)
* All content is published on a single website, [`www.union.ai/docs`](/docs/v2/root/).
* The website has a variant selector at the top of the page that lets you choose which variant you want to view:
  * Flyte OSS
  * Union Serverless
  * Union BYOC
  * Union Self-managed
* There is also a version selector. Currently two versions are available:
  * v1 (the original docs for Flyte/Union 1.x)
  * v2 (the new docs for Flyte/Union 2.0, which is the one you are currently viewing)
## Versions
The two versions of the docs are stored in separate branches of the GitHub repository:
* [`v1` branch](https://github.com/unionai/unionai-docs/tree/v1) for the v1 docs.
* [`main` branch](https://github.com/unionai/unionai-docs) for the v2 docs.
See **Contributing docs and examples > Versions** for more details.
## Variants
Within each branch the multiple variants are supported by using conditional rendering:
* Each page of content has a `variants` front matter field that specifies which variants the page is applicable to.
* Within each page, rendering logic can be used to include or exclude content based on the selected variant.
The result is that:
* Content that is common to all variants is authored and stored once.
There is no need to keep multiple copies of the same content in-sync.
* Content specific to a variant is conditionally rendered based on the selected variant.
See **Contributing docs and examples > Variants** for more details.
## Both Flyte and Union docs are open source
Since the docs are now combined in one repository, and the Flyte docs are open source, the Union docs are also open source.
All the docs are available for anyone to contribute to: Flyte contributors, Union customers, and Union employees.
If you are a Flyte contributor, you will be contributing docs related to Flyte features and functionality, but in many cases these features and functionality will also be available in Union.
Because the docs site is a single source for all the documentation, when you make changes related to Flyte that are also valid for Union you do so in the same place.
This is by design and is a key feature of the docs site.
## Subpages
- **Contributing docs and examples > Quick start**
- **Contributing docs and examples > Variants**
- **Contributing docs and examples > Versions**
- **Contributing docs and examples > Authoring**
- **Contributing docs and examples > Shortcodes**
- **Contributing docs and examples > Redirects**
- **Contributing docs and examples > API docs**
- **Contributing docs and examples > Publishing**
=== PAGE: https://www.union.ai/docs/v2/byoc/community/contributing-docs/quick-start ===
# Quick start
## Prerequisites
The docs site is built using the [Hugo](https://gohugo.io/) static site generator.
You will need to install it to build the site locally.
See [Hugo Installation](https://gohugo.io/getting-started/installing/).
## Clone the repository
Clone the [`unionai/unionai-docs`](https://github.com/unionai/unionai-docs) repository to your local machine.
The content is located in the `content/` folder in the form of Markdown files.
The hierarchy of the files and folders under `content/` directly reflects the URL and navigation structure of the site.
## Live preview
Next, set up the live preview: go to the root of your local repository checkout and copy `hugo.local.toml~sample` to `hugo.local.toml`:
```shell
$ cp hugo.local.toml~sample hugo.local.toml
```
This file contains the configuration for the live preview.
By default, it displays the `flyte` variant of the docs site and enables the `show_inactive`, `highlight_active`, and `highlight_keys` flags.
Now you can start the live preview server by running:
```shell
$ make dev
```
This will build the site and launch a local server at `http://localhost:1313`.
Open that URL to see the live preview. Leave the server running.
As you edit the content you will see the changes reflected in the live preview.
## Distribution build
To build the site for distribution, run:
```shell
$ make dist
```
This will build the site locally just as it is built by the Cloudflare CI for production.
You can view the result of the build by running a local server:
```shell
$ make serve
```
This will start a local server at `http://localhost:9000` and serve the contents of the `dist/` folder. You can also specify a port number:
```shell
$ make serve PORT=<port>
```
=== PAGE: https://www.union.ai/docs/v2/byoc/community/contributing-docs/variants ===
# Variants
The docs site can show or hide content based on the currently selected variant.
There are separate mechanisms for:
* Including or excluding entire pages based on the selected variant.
* Conditional rendering of content within a page based on the selected variant using an if-then-like construct.
* Rendering keywords as variables that change based on the selected variant.
Currently, the docs site supports four variants:
- **Flyte OSS**: The open-source Flyte project.
- **Serverless**: The Union.ai product that is hosted and managed by Union AI.
- **BYOC**: The Union.ai product that is hosted on the customer's infrastructure but managed by Union AI.
- **Self-managed**: The Union.ai product that is hosted and managed by the customer.
Each variant is referenced in the page logic using its respective code name: `flyte`, `serverless`, `byoc`, or `selfmanaged`.
The available set of variants is defined in the `config.<variant>.toml` files in the root of the repository.
## Variants at the whole-page level
The docs site can show or hide entire pages based on the selected variant.
Not all pages are available in all variants because features differ across the variants.
On the public website, if you are on a page in one variant and switch to a different variant, you will be taken to the same page in the new variant *if it exists*.
If it does not exist, you will see a message indicating that the page is not available in the selected variant.
In the source Markdown, whether a page is present in a given variant is governed by the `variants` field in the page's front matter.
For example, if you look at the Markdown source for [this page (the page you are currently viewing)](https://github.com/unionai/docs/content/community/contributing-docs.md), you will see the following front matter:
```markdown
---
title: Platform overview
weight: 1
variants: +flyte +serverless +byoc +selfmanaged
---
```
The `variants` field has the value:
`+flyte +serverless +byoc +selfmanaged`
The `+` indicates that the page is available for the specified variant.
In this case, the page is available for all four variants.
If you wanted to make the page available for only the `flyte` and `serverless` variants, you would change the `variants` field to:
`+flyte +serverless -byoc -selfmanaged`
In [live preview mode](./authoring-core-content#live-preview) with the `show_inactive` flag enabled, you will see all pages in the navigation tree, with the ones unavailable for the current variant grayed out.
As you can see, the `variants` field expects a space-separated list of keywords:
* The code names for the current variants are `flyte`, `serverless`, `byoc`, and `selfmanaged`.
* All supported variants must be included explicitly in every `variants` field with a leading `+` or `-`. There is no default behavior.
* The supported variants are configured in the root of the repository in the files named `config.<variant>.toml`.
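The rules above are strict enough to check mechanically. As an illustration only, the following parses a `variants` value and rejects malformed, unknown, duplicate, or missing entries. The helper `parse_variants` is hypothetical and is not part of the docs tooling. The four code names are hard-coded here, whereas the real list comes from the repository's configuration files.

```python
KNOWN_VARIANTS = ("flyte", "serverless", "byoc", "selfmanaged")


def parse_variants(field: str) -> dict[str, bool]:
    """Parse a `variants` front matter value into {variant: included}.

    Every token must be +name or -name, names must be known, and every
    known variant must appear exactly once (there is no default).
    """
    result: dict[str, bool] = {}
    for token in field.split():
        sign, name = token[0], token[1:]
        if sign not in ("+", "-") or not name:
            raise ValueError(f"Malformed token {token!r}: expected +name or -name")
        if name not in KNOWN_VARIANTS:
            raise ValueError(f"Unknown variant {name!r}")
        if name in result:
            raise ValueError(f"Variant {name!r} listed more than once")
        result[name] = sign == "+"
    missing = [v for v in KNOWN_VARIANTS if v not in result]
    if missing:
        raise ValueError(f"Missing variants: {', '.join(missing)}")
    return result
```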
## Conditional rendering within a page
Content can also differ *within a page* based on the selected variant.
This is done with conditional rendering using the `{{< variant >}}` and `{{< key >}}` [Hugo shortcodes](https://gohugo.io/content-management/shortcodes/).
### {{< variant >}}
The syntax for the `{{< variant >}}` shortcode is:
```markdown
{{< variant <variant-list> >}}
...
{{< /variant >}}
```
Where `<variant-list>` is a space-separated list of the code names of the variants you want to show the content for.
Note that the variant construct can only directly contain other shortcode constructs, not plain Markdown.
In the most common case, you will want to use the `{{< markdown >}}` shortcode (which can contain Markdown) inside the `{{< variant >}}` shortcode to render Markdown content, like this:
```markdown
{{< variant serverless byoc >}}
{{< markdown >}}
This content is only visible in the `serverless` and `byoc` variants.
{{< /markdown >}}
{{< button-link text="Contact Us" target="https://union.ai/contact" >}}
{{< /variant >}}
```
For more details on the `{{< variant >}}` shortcode, see **Contributing docs and examples > Shortcodes > Component Library > `{{< variant >}}`**.
### {{< key >}}
The syntax for the `{{< key >}}` shortcode is:
```markdown
{{< key <key-name> >}}
```
Where `<key-name>` is the name of the key you want to render.
For example, if you want to render the product name keyword, you would use:
```markdown
{{< key product_name >}}
```
The available key names are defined in the `[params.key]` section of the `hugo.site.toml` configuration file in the root of the repository.
For example, the `product_name` key used above is defined in that file as:
```toml
[params.key.product_name]
flyte = "Flyte"
serverless = "Union.ai"
byoc = "Union.ai"
selfmanaged = "Union.ai"
```
This means that in any content that appears in the `flyte` variant of the site, the `{{< key product_name >}}` shortcode will be replaced with `Flyte`, and in any content that appears in the `serverless`, `byoc`, or `selfmanaged` variants, it will be replaced with `Union.ai`.
For more details on the `{{< key >}}` shortcode, see **Contributing docs and examples > Shortcodes > Component Library > `{{< key >}}`**.
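Conceptually, the `key` shortcode is a per-variant lookup table. The Python sketch below mimics that substitution so the mapping is easy to see. It is a hypothetical illustration, not how Hugo actually processes shortcodes. The `KEYS` dictionary mirrors only the `product_name` table shown above.

```python
import re

# Mirrors the [params.key.product_name] table shown above; other keys omitted.
KEYS: dict[str, dict[str, str]] = {
    "product_name": {
        "flyte": "Flyte",
        "serverless": "Union.ai",
        "byoc": "Union.ai",
        "selfmanaged": "Union.ai",
    },
}

# Matches e.g. {{< key product_name >}} and captures the key name.
KEY_SHORTCODE = re.compile(r"\{\{<\s*key\s+(\w+)\s*>\}\}")


def render_keys(text: str, variant: str) -> str:
    """Replace every key shortcode in `text` with its value for `variant`."""
    return KEY_SHORTCODE.sub(lambda m: KEYS[m.group(1)][variant], text)


print(render_keys("Welcome to {{< key product_name >}}!", "flyte"))  # Welcome to Flyte!
```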
## Full example
Here is a full example. If you look at the Markdown source for [this page (the page you are currently viewing)](https://github.com/unionai/docs/content/community/contributing-docs/variants.md), you will see the following section:
```markdown
> **This text is visible in all variants.**
>
> {{< variant flyte >}}
> {{< markdown >}}
>
> **This text is only visible in the `flyte` variant.**
>
> {{< /markdown >}}
> {{< /variant >}}
> {{< variant serverless byoc selfmanaged >}}
> {{< markdown >}}
>
> **This text is only visible in the `serverless`, `byoc`, and `selfmanaged` variants.**
>
> {{< /markdown >}}
> {{< /variant >}}
>
> **Below is a `{{< key product_full_name >}}` shortcode.
> It will be replaced with the current variant's full name:**
>
> **{{< key product_full_name >}}**
```
This Markdown source is rendered as:
> **This text is visible in all variants.**
>
> **This text is only visible in the `serverless`, `byoc`, and `selfmanaged` variants.**
>
> **Below is a `{{< key product_full_name >}}` shortcode.
> It will be replaced with the current variant's full name:**
>
> **Union.ai BYOC**
If you switch between variants with the variant selector at the top of the page, you will see the content change accordingly.
## Adding a new variant
A variant is a term we use to identify a product or major section of the site.
Each variant has a dedicated token that identifies it, and every resource is
tagged to be either included or excluded when the variant is built.
> Adding a new variant is a rare event and should be reserved for new products
> or major developments.
>
> If you think adding a new variant is the way
> to go, please check with the infra admin before doing all
> the work below, so that you don't waste your time.
### Location
When deployed, each variant occupies a top-level folder:
`https://<domain>/<variant>/`
For example, if we have a variant `acme`, then when built the content goes to:
`https://<domain>/acme/`
### Creating a new variant
Creating a new variant requires changes to the following files:
| File | Changes |
| ----------------------- | -------------------------------------------------------------- |
| `hugo.site.toml` | Add to `params.variant_weights` and all `params.key` |
| `hugo.toml` | Add to `params.search` |
| `Makefile` | Add a new `make variant` call to the `dist` target |
| `*.md` | Add either `+<variant>` or `-<variant>` to all content pages |
| `config.<variant>.toml` | Create a new file and configure `baseURL` and `params.variant` |
### Testing the new variant
As you develop the new variant, it is recommended to have a semi-stable `pre-release/<variant>`
branch to confirm that everything is working and the content looks good. It also allows others
to collaborate by creating PRs against it (`base=pre-release/<variant>` instead of `main`)
without trampling on each other, and it allows for parallel reviews.
Once the variant branch is ready, merge it into `main`.
### Building (just) the variant
You can build the production version of the variant,
which also runs all the safety checks,
by invoking the variant build:
```shell
$ make variant VARIANT=<variant>
```
For example:
```shell
$ make variant VARIANT=serverless
```
=== PAGE: https://www.union.ai/docs/v2/byoc/community/contributing-docs/versions ===
# Versions
In addition to the product variants, the docs site also supports multiple versions of the documentation.
The version selector is located at the top of the page, next to the variant selector.
Versions and variants are independent of each other, with the version being "above" the variant in the URL hierarchy.
The URL for version `v2` of the current page (the one you are on right now) in the Flyte variant is:
`/docs/v2/flyte//community/contributing-docs/versions`
while the URL for version `v1` of the same page is:
`/docs/v1/flyte//community/contributing-docs/versions`
### Versions are branches
The versioning system is based on long-lived Git branches in the `unionai/unionai-docs` GitHub repository:
- The `main` branch contains the latest version of the documentation. Currently, `v2`.
- Other versions of the docs are contained in branches named `vX`, where `X` is the major version number. Currently, there is one other version, `v1`.
## Archive versions
An "archive version" is a static snapshot of the site at a given point in time.
It freezes a specific version of the site for historical purposes, preserving its content and structure.
### How to create an archive version
1. Create a new branch from `main` named `vX`, e.g. `v3`.
2. Add the version to the `VERSION` field in the `makefile.inc` file, e.g. `VERSION := v3`.
3. Add the version to the `versions` field in the `hugo.ver.toml` file, e.g. `versions = [ "v1", "v2", "v3" ]`.
> [!NOTE]
> **Important:** You must update the `versions` field in **ALL** published and archived versions of the site.
### Publishing an archive version
> [!NOTE]
> This step can only be done by a Union employee.
1. Update the `docs_archive_versions` in the `docs_archive_locals.tf` Terraform file
2. Create a PR for the changes
3. Once the PR is merged, run the production pipeline to activate the new version
=== PAGE: https://www.union.ai/docs/v2/byoc/community/contributing-docs/authoring ===
# Authoring
## Getting started
Content is located in the `content` folder.
To create a new page, simply create a new Markdown file in the appropriate folder and start writing it!
## Target the right branch
Remember that there are two production branches in the docs: `main` and `v1`.
* **For Flyte or Union 1, create a branch off of `v1` and target your pull request to `v1`**
* **For Flyte or Union 2, create a branch off of `main` and target your pull request to `main`**
## Live preview
While editing, you can use Hugo's local live preview capabilities.
Simply run:
```shell
$ make dev
```
This will build the site and launch a local server at `http://localhost:1313`.
Open that URL to see the live preview. Leave the server running.
As you edit, the preview updates automatically.
See **Contributing docs and examples > Publishing** for how to set up your machine.
## Pull Requests + Site Preview
Pull requests will create a preview build of the site on Cloudflare.
Check the pull request for a dynamic link to the site changes within that PR.
## Page Visibility
This site uses variants, which means different "flavors" of the content.
For a given page, its variant visibility is governed by the `variants:` field in the front matter of the page source.
For each variant you specify `+` to include or `-` to exclude it.
For example:
```markdown
---
title: My Page
variants: -flyte +serverless +byoc -selfmanaged
---
```
In this example the page will be:
* Included in Serverless and BYOC.
* Excluded from Flyte and Self-managed.
> [!NOTE]
> All variants must be explicitly listed in the `variants` field.
> This helps avoid missing or extraneous pages.
## Page order
Pages are ordered by the value of the `weight` field (an integer >= 0) in the front matter of the page:
1. The higher the weight, the lower the page sits in the navigation ordering among its peers in the same folder.
2. Pages with no weight field (or `weight = 0`) will be ordered last.
3. Pages of the same weight will be sorted alphabetically by their title.
4. Folders are ordered among their peers (other folders and pages at the same level of the hierarchy) by the weight of their `_index.md` page.
For example:
```markdown
---
title: My Page
weight: 3
---
```
## Page settings
| Setting | Type | Description |
| ------------------ | ---- | ----------------------------------------------------------------------------------- |
| `top_menu` | bool | If `true`, the item becomes a tab at the top and its hierarchy goes to the sidebar. |
| `sidebar_expanded` | bool | If `true`, the section is permanently expanded in the sidebar. |
| `site_root` | bool | If `true`, indicates that the page is the site landing page. |
| `toc_max` | int | Maximum heading level to include in the right-hand table of contents. |
## Conditional Content
The site has "flavors" of the documentation. We leverage the `{{* variant */>}}` tag to control
which content is rendered on which flavor.
Refer to **Contributing docs and examples > Shortcodes > Variants** for detailed explanation.
## Warnings and Notices
You can write regular Markdown and use the notation below to create information and warning boxes:
```markdown
> [!NOTE] This is the note title
> You write the note content here. It can be
> anything you want.
```
Or if you want a warning:
```markdown
> [!WARNING] This is the title of the warning
> And here you write what you want to warn about.
```
## Special Content Generation
There are various short codes to generate content or special components (tabs, dropdowns, etc.)
Refer to **Contributing docs and examples > Shortcodes** for more information.
## Python Generated Content
You can generate pages from markdown-commented Python files.
At the top of your `.md` file, add:
```markdown
---
layout: py_example
example_file: /path/to/your/file.py
run_command: union run --remote tutorials//path/to/your/file.py main
source_location: https://www.github.com/unionai/unionai-examples/tree/main/tutorials/path/to/your/file.py
---
```
Where the referenced file looks like this:
```python
# # Credit Default Prediction with XGBoost & NVIDIA RAPIDS
#
# In this tutorial, we will use NVIDIA RAPIDS `cudf` DataFrame library for preprocessing
# data and XGBoost, an optimized gradient boosting library, for credit default prediction.
# We'll learn how to declare NVIDIA `A100` for our training function and `ImageSpec`
# for specifying our python dependencies.
# {{run-on-union}}
# ## Declaring workflow dependencies
#
# First, we import all the dependencies required by this workflow:
import os
import gc
from pathlib import Path
from typing import Tuple
import fsspec
from flytekit import task, workflow, current_context, Resources, ImageSpec, Deck
from flytekit.types.file import FlyteFile
from flytekit.extras.accelerators import A100
```
Note that the text content is embedded in comments as Markdown, and the code is normal Python code.
The generator will convert the markdown into normal page text content and the code into code blocks within that Markdown content.
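
A minimal sketch of that conversion, assuming comment lines start with `# ` (or are a bare `#`) and everything else is code. `py_to_markdown` is a hypothetical function, not the real generator, and it ignores front matter and the `{{run-on-union}}` marker.

```python
# Hypothetical sketch of the py_example conversion; the real generator may differ.
def py_to_markdown(source: str) -> str:
    out: list[str] = []
    code: list[str] = []

    def flush_code() -> None:
        # Emit buffered code as one fenced block, dropping trailing blank lines.
        while code and not code[-1].strip():
            code.pop()
        if code:
            out.extend(["```python", *code, "```"])
        code.clear()

    for line in source.splitlines():
        if line.startswith("#"):
            # A top-level comment ends any pending code block and becomes Markdown.
            flush_code()
            out.append(line[2:] if line.startswith("# ") else line[1:])
        elif code or line.strip():
            # Code lines (skipping blank lines before the first one) are buffered.
            code.append(line)
    flush_code()
    return "\n".join(out)


sample = "# # Title\n#\n# Some text.\nimport os\n\nprint(os.sep)\n"
print(py_to_markdown(sample))
```

Indented comments inside functions stay in the code block, because only lines that begin with `#` are treated as Markdown.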
### Run on Union Instructions
You can add the **Run on Union** instructions anywhere in the content.
Mark the location where you want them with `{{run-on-union}}`, like this:
```markdown
# The quick brown fox wants to see the Union instructions.
#
# {{run-on-union}}
#
# And it shall have it.
```
The resulting **Run on Union** section in the rendered docs will include the run command and source location,
specified as `run_command` and `source_location` in the front matter of the corresponding `.md` page.
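
The substitution can be pictured as a plain string replacement. This sketch is hypothetical: the real template renders its own styled section, and `expand_run_on_union` and its output layout are assumptions; only the `run_command` and `source_location` front matter fields come from this page.

```python
# Hypothetical sketch; the real template renders its own styled section.
RUN_ON_UNION = "{{run-on-union}}"


def expand_run_on_union(markdown: str, front_matter: dict[str, str]) -> str:
    # Build the replacement from the two front matter fields named on this page.
    section = "\n".join([
        "**Run on Union**",
        "",
        "```shell",
        f"$ {front_matter['run_command']}",
        "```",
        "",
        f"Source: {front_matter['source_location']}",
    ])
    return markdown.replace(RUN_ON_UNION, section)


page = "The quick brown fox wants to see the Union instructions.\n\n{{run-on-union}}\n\nAnd it shall have it."
fm = {
    "run_command": "union run --remote tutorials/example.py main",  # hypothetical value
    "source_location": "https://example.com/example.py",  # hypothetical value
}
print(expand_run_on_union(page, fm))
```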
## Jupyter Notebooks
You can also generate pages from Jupyter notebooks.
At the top of your `.md` file, add:
```markdown
---
jupyter_notebook: /path/to/your/notebook.ipynb
---
```
Then run the `Makefile.jupyter` target to generate the page.
```shell
$ make -f Makefile.jupyter
```
> [!NOTE]
> You must `uv sync` and activate the environment in `tools/jupyter_generator` before running the
> `Makefile.jupyter` target, or make sure all the necessary dependencies are installed for yourself.
**Committing the change:** When the PR is pushed, a check for consistency between the notebook and its source will run. Please ensure that if you change the notebook, you re-run the `Makefile.jupyter` target to update the page.
## Mapped Keys (`{{* key */>}}`)
`key` is a special shortcode that maps a single name to a different value in each variant.
For example, the product name changes depending on whether the variant is Flyte, Union BYOC, and so on. Instead of wrapping each mention in a `variant` block,
you can define a single key, `product_full_name`, and its value will change automatically with the variant.
Please refer to **Contributing docs and examples > Authoring > {{* key */>}} shortcode** for more details.
## Mermaid Graphs
To embed Mermaid diagrams in a page, insert the code inside a block like this:
```mermaid
your mermaid graph goes here
```
Also add `mermaid: true` to the front matter of your page to enable rendering.
> [!NOTE]
> You can use [Mermaid's playground](https://www.mermaidchart.com/play) to design diagrams and get the code.
=== PAGE: https://www.union.ai/docs/v2/byoc/community/contributing-docs/shortcodes ===
# Shortcodes
This site provides special blocks, called shortcodes, that generate content and components for the Union documentation.
> [!NOTE]
> You can see examples by running the dev server and visiting
> [`http://localhost:1313/__docs_builder__/shortcodes/`](http://localhost:1313/__docs_builder__/shortcodes/).
> Note that this page is only visible locally. It does not appear in the menus or in the production build.
>
> If you need instructions on how to create the local environment and get the
> `localhost:1313` server running, please refer to the **Contributing docs and examples > Shortcodes > local development guide**.
## How to specify a "shortcode"
A shortcode is a string that Hugo replaces with generated HTML when it renders the page.
Depending on the shortcode, you can pass parameters to it, place content inside it, or both.
> [!NOTE]
> If a shortcode contains content, it must have a closing tag.
Examples:
* A shortcode that just outputs something
```markdown
{{* key product_name */>}}
```
* A shortcode that has content inside
```markdown
{{* markdown */>}}
* Your markdown
* goes here
{{* /markdown */>}}
```
* A shortcode with parameters
```markdown
{{* link-card target="union-sdk" icon="workflow" title="Union SDK" */>}}
The Union SDK provides the Python API for building Union workflows and apps.
{{* /link-card */>}}
```
> [!NOTE]
> If you're wondering why we have a `{{* markdown */>}}` when we can generate markdown at the top level, it is due to a quirk in Hugo:
> * At the top level of the page, Hugo can render markdown directly, interspersed with shortcodes.
> * However, *inside* a container shortcode, Hugo can only render *either* other shortcodes *or* Markdown.
> * The `{{* markdown */>}}` shortcode is designed to contain only Markdown (not other shortcodes).
> * All other container shortcodes are designed to contain only other shortcodes.
## Variants
The main difference between this site and other documentation sites is that we generate multiple "flavors" of the documentation that differ slightly from each other. We call these "variants."
When you want a specific part of your content to appear only in a particular flavor, say "BYOC", surround it with `variant`.
> [!NOTE]
> `variant` is a container, so inside you will specify what you are wrapping.
> You can wrap any of the shortcodes listed in this document.
Example:
```markdown
{{* variant serverless byoc */>}}
{{* markdown */>}}
**The quick brown fox signed up for Union!**
{{* /markdown */>}}
{{* button-link text="Contact Us" target="https://union.ai/contact" */>}}
{{* /variant */>}}
```
## Component Library
### `{{* audio */>}}`
Generates an audio media player.
### `{{* grid */>}}`
Creates a fixed column grid for lining up content.
### `{{* variant */>}}`
Filters content based on which flavor you're seeing.
### `{{* link-card */>}}`
A floating, clickable, navigable card.
### `{{* markdown */>}}`
Generates a markdown block, to be used inside containers such as `{{* dropdown */>}}` or `{{* variant */>}}`.
### `{{* multiline */>}}`
Generates a single paragraph that spans multiple lines. Useful for making a multiline table cell.
### `{{* tabs */>}}` and `{{* tab */>}}`
Generates a tab panel with content switching per tab.
### `{{* key */>}}`
Outputs one of the pre-defined keywords.
Enables inline text that differs per variant without using the heavyweight `{{* variant */>}}...{{* /variant */>}}` construct.
Take, for example, the following:
```markdown
The {{* key product_name */>}} platform is awesome.
```
In the Flyte variant of the site this will render as:
> The Flyte platform is awesome.
In the BYOC, Self-managed, and Serverless variants of the site, it will render as:
> The Union.ai platform is awesome.
You can add keywords and specify their value, per variant, in `hugo.toml`:
```toml
[params.key.product_full_name]
flyte = "Flyte"
serverless = "Union Serverless"
byoc = "Union BYOC"
selfmanaged = "Union Self-managed"
```
#### List of available keys
| Key | Description | Example Usage (Flyte → Union) |
| ----------------- | ------------------------------------- | ---------------------------------------------------------------------- |
| default_project | Default project name used in examples | `{{* key default_project */>}}` → "flytesnacks" or "default" |
| product_full_name | Full product name | `{{* key product_full_name */>}}` → "Flyte OSS" or "Union.ai Serverless" |
| product_name | Short product name | `{{* key product_name */>}}` → "Flyte" or "Union.ai" |
| product | Lowercase product identifier | `{{* key product */>}}` → "flyte" or "union" |
| kit_name | SDK name | `{{* key kit_name */>}}` → "Flytekit" or "Union" |
| kit | Lowercase SDK identifier | `{{* key kit */>}}` → "flytekit" or "union" |
| kit_as | SDK import alias | `{{* key kit_as */>}}` → "fl" or "union" |
| kit_import | SDK import statement | `{{* key kit_import */>}}` → "flytekit as fl" or "union" |
| kit_remote | Remote client class name | `{{* key kit_remote */>}}` → "FlyteRemote" or "UnionRemote" |
| cli_name | CLI tool name | `{{* key cli_name */>}}` → "Pyflyte" or "Union" |
| cli | Lowercase CLI tool identifier | `{{* key cli */>}}` → "pyflyte" or "union" |
| ctl_name | Control tool name | `{{* key ctl_name */>}}` → "Flytectl" or "Uctl" |
| ctl | Lowercase control tool identifier | `{{* key ctl */>}}` → "flytectl" or "uctl" |
| config_env | Configuration environment variable | `{{* key config_env */>}}` → "FLYTECTL_CONFIG" or "UNION_CONFIG" |
| env_prefix | Environment variable prefix | `{{* key env_prefix */>}}` → "FLYTE" or "UNION" |
| docs_home | Documentation home URL | `{{* key docs_home */>}}` → "/docs/flyte" or "/docs/serverless" |
| map_func | Map function name | `{{* key map_func */>}}` → "map_task" or "map" |
| logo | Logo image filename | `{{* key logo */>}}` → "flyte-logo.svg" or "union-logo.svg" |
| favicon | Favicon image filename | `{{* key favicon */>}}` → "flyte-favicon.ico" or "union-favicon.ico" |
### `{{* download */>}}`
Generates a download link.
Parameters:
- `url`: The URL to download from
- `filename`: The filename to save the file as
- `text`: The text to display for the download link
Example:
```markdown
{{* download "/_static/public/public-key.txt" "public-key.txt" */>}}
```
### `{{* docs_home */>}}`
Produces a link to the home page of the documentation for a specific variant.
Example:
```markdown
[See this in Flyte]({{* docs_home flyte */>}}/wherever/you/want/to/go/in/flyte/docs)
```
### `{{* py_class_docsum */>}}`, `{{* py_class_ref */>}}`, and `{{* py_func_ref */>}}`
Helper functions to track Python classes in Flyte documentation, so we can link them to
the appropriate documentation.
Parameters:
- name of the class
- text to add to the link
Example:
```markdown
Please see {{* py_class_ref flyte.core.Image */>}} for more details.
```
### `{{* icon name */>}}`
Uses a named icon in the content.
Example:
```markdown
[Download {{* icon download */>}}](/download)
```
### `{{* code */>}}`
Includes a code snippet or file.
Parameters:
- `file`: The path to the file to include.
- `fragment`: The name of the fragment to include.
- `from`: The line number to start including from.
- `to`: The line number to stop including at.
- `lang`: The language of the code snippet.
- `show_fragments`: Whether to show the fragment names in the code block.
- `highlight`: Whether to highlight the code snippet.
The examples in this section use this file as a base:
```python
def main():
    """
    A sample function
    """
    return 42

# {{docs-fragment entrypoint}}
if __name__ == "__main__":
    main()
# {{/docs-fragment}}
```
*Source: /_static/__docs_builder__/sample.py*
Link to [/_static/__docs_builder__/sample.py](/_static/__docs_builder__/sample.py)
#### Including a section of a file: `{{docs-fragment}}`
```markdown
{{* code file="/_static/__docs_builder__/sample.py" fragment=entrypoint lang=python */>}}
```
Effect:
```python
if __name__ == "__main__":
    main()
```
*Source: /_static/__docs_builder__/sample.py*
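
The `fragment` filter can be approximated in Python as follows. This is a hypothetical sketch of the behavior, not the shortcode's Hugo template; it assumes marker lines use the `# {{docs-fragment name}}` and `# {{/docs-fragment}}` form shown above and does not handle nested fragments.

```python
# Hypothetical sketch of the `fragment` filter; not the shortcode's Hugo template.
import re

START = re.compile(r"#\s*\{\{docs-fragment\s+([\w-]+)\}\}")
END = re.compile(r"#\s*\{\{/docs-fragment\}\}")

SAMPLE = '''def main():
    """
    A sample function
    """
    return 42

# {{docs-fragment entrypoint}}
if __name__ == "__main__":
    main()
# {{/docs-fragment}}
'''


def extract_fragment(source: str, name: str) -> str:
    kept: list[str] = []
    inside = False
    for line in source.splitlines():
        start = START.search(line)
        if start:
            # Entering a fragment: keep lines only if it is the requested one.
            inside = start.group(1) == name
        elif END.search(line):
            inside = False
        elif inside:
            kept.append(line)  # Marker lines themselves are never kept.
    return "\n".join(kept)


print(extract_fragment(SAMPLE, "entrypoint"))
```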
#### Including a file with a specific line range: `from` and `to`
```markdown
{{* code file="/_static/__docs_builder__/sample.py" from=2 to=4 lang=python */>}}
```
Effect:
```python
    """
    A sample function
    """
```
*Source: /_static/__docs_builder__/sample.py*
#### Including a whole file
Simply specify no filters, just the `file` attribute:
```markdown
{{* code file="/_static/__docs_builder__/sample.py" */>}}
```
> [!NOTE]
> Note that without `show_fragments=true` the fragment markers will not be shown.
Effect:
```python
def main():
    """
    A sample function
    """
    return 42

if __name__ == "__main__":
    main()
```
*Source: /_static/__docs_builder__/sample.py*
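
The `from`/`to` and `show_fragments` behavior can be sketched the same way. This assumes `from` and `to` are 1-based and inclusive, which this page does not state explicitly; `select_lines` is a hypothetical helper, not the shortcode itself.

```python
# Hypothetical sketch; assumes `from`/`to` are 1-based and inclusive.
import re
from typing import Optional

MARKER = re.compile(r"^\s*#\s*\{\{/?docs-fragment[^}]*\}\}\s*$")

SAMPLE = '''def main():
    """
    A sample function
    """
    return 42

# {{docs-fragment entrypoint}}
if __name__ == "__main__":
    main()
# {{/docs-fragment}}
'''


def select_lines(source: str, start: Optional[int] = None, end: Optional[int] = None,
                 show_fragments: bool = False) -> str:
    lines = source.splitlines()
    # Convert 1-based inclusive bounds to a Python slice; None means "no limit".
    selected = lines[(start or 1) - 1:end]
    if not show_fragments:
        # Drop fragment marker lines unless explicitly requested.
        selected = [line for line in selected if not MARKER.match(line)]
    return "\n".join(selected)


print(select_lines(SAMPLE, start=2, end=4))  # only the docstring lines
print(select_lines(SAMPLE))  # whole file, markers hidden
```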
=== PAGE: https://www.union.ai/docs/v2/byoc/community/contributing-docs/redirects ===
# Redirects
We use Cloudflare's Bulk Redirect to map URLs that moved to their new location,
so the user does not get a 404 using the old link.
The redirect files are in CSV format, with the following structure:
`,,302,TRUE,FALSE,TRUE,TRUE`
- ``: the URL without `https://`
- ``: the full URL (including `https://`) to send the user to
Redirects are recorded in `redirects.csv` file in the root of the repository.
To take effect, this file must be applied to the production environment on Cloudflare by a Union employee.
If you need to add a new redirect, please create a pull request with the change to `redirects.csv` and a note indicating that you would like to have it applied to production.
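
To draft an entry, a small Python helper can assemble the seven fields. `redirect_entry` is hypothetical, the URLs in the usage line are placeholders, and the trailing `302,TRUE,FALSE,TRUE,TRUE` values are copied verbatim from the format above. It uses the standard `csv` module so that a URL containing a comma would be quoted correctly.

```python
# Hypothetical helper for drafting a redirects.csv line; not part of the docs tooling.
import csv
import io

# Trailing fields copied verbatim from the documented format.
FLAGS = ["302", "TRUE", "FALSE", "TRUE", "TRUE"]


def redirect_entry(source_url: str, target_url: str) -> str:
    # The source field must not include the scheme.
    source = source_url.removeprefix("https://")
    buf = io.StringIO()
    csv.writer(buf, lineterminator="\n").writerow([source, target_url, *FLAGS])
    return buf.getvalue().rstrip("\n")


# Placeholder URLs for illustration only.
print(redirect_entry("https://docs.union.ai/old-page", "https://www.union.ai/docs/new-page"))
```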
## `docs.union.ai` redirects
For redirects from the old `docs.union.ai` site to the new `www.union.ai/docs` site, we use the original request URL. For example:
| Field | Value |
|-|-|
| Request URL | `https://docs.union.ai/administration` |
| Target URL | `/docs/v1/byoc//user-guide/administration` |
| Redirect Entry | `docs.union.ai/administration,/docs/v1/byoc//user-guide/administration,302,TRUE,FALSE,TRUE,TRUE` |
## `docs.flyte.org` redirects
For redirects from the old `docs.flyte.org` site to the new `www.union.ai/docs` site, we replace `docs.flyte.org` in the request URL with the special prefix `www.union.ai/_r_/flyte`. For example:
| Field | Value |
|-|-|
| Request URL | `https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.dynamic.html` |
| Converted request URL | `www.union.ai/_r_/flyte/projects/flytekit/en/latest/generated/flytekit.dynamic.html` |
| Target URL | `/docs/v1/flyte//api-reference/flytekit-sdk/packages/flytekit.core.dynamic_workflow_task/` |
| Redirect Entry | `www.union.ai/_r_/flyte/projects/flytekit/en/latest/generated/flytekit.dynamic.html,/docs/v1/flyte//api-reference/flytekit-sdk/packages/flytekit.core.dynamic_workflow_task/,302,TRUE,FALSE,TRUE,TRUE` |
The special prefix is used so that we can include both `docs.union.ai` and `docs.flyte.org` redirects in the same file and apply them on the same domain (`www.union.ai`).
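
The conversion step from the table above can be expressed as a short function. `converted_request_url` is a hypothetical helper; it keeps only the path, so query strings and fragments are dropped, which is an assumption rather than documented behavior.

```python
# Hypothetical sketch of the docs.flyte.org prefix conversion described above.
from urllib.parse import urlsplit

FLYTE_PREFIX = "www.union.ai/_r_/flyte"


def converted_request_url(request_url: str) -> str:
    parts = urlsplit(request_url)
    if parts.netloc != "docs.flyte.org":
        raise ValueError(f"not a docs.flyte.org URL: {request_url}")
    # Replace scheme and host with the special prefix; keep only the path.
    return FLYTE_PREFIX + parts.path


print(converted_request_url(
    "https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.dynamic.html"
))
```

Applied to the request URL in the table, this yields the converted request URL shown there.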
=== PAGE: https://www.union.ai/docs/v2/byoc/community/contributing-docs/api-docs ===
# API docs
You can import Python APIs and host them on the site. To do that you will use
the `tools/api_generator` to parse and create the appropriate markdown.
Please refer to [`api_generator/README`](https://github.com/unionai/docs/blob/main/tools/api_generator/README.md) for more details.
## API naming convention
All the buildable APIs have a makefile at the repository root, named in the form:
`Makefile.api.`
To build one, run `make -f Makefile.api.` and follow the setup
requirements in the `README.md` file linked above.
## Package Resource Resolution
When scanning the packages we need to know when to include or exclude an object
(class, function, variable) from the documentation. The parser will follow this
workflow to decide, in order, if the resource must be in or out:
1. `__all__: List[str]` package-level variable is present: Only resources
listed will be exposed. All other resources are excluded.
Example:
```python
from http import HTTPStatus, HTTPMethod

__all__ = ["HTTPStatus", "LocalThingy"]


class LocalThingy:
    ...


class AnotherLocalThingy:
    ...
```
In this example only `HTTPStatus` and `LocalThingy` will show in the docs.
Both `HTTPMethod` and `AnotherLocalThingy` are ignored.
2. If `__all__` is not present, these rules are observed:
- All imported objects are ignored
- All objects starting with `_` are ignored
Example:
```python
from http import HTTPStatus, HTTPMethod


class _LocalThingy:
    ...


class AnotherLocalThingy:
    ...


def _a_func():
    ...


def b_func():
    ...
```
In this example only `AnotherLocalThingy` and `b_func` will show in the docs.
None of the imports, nor `_LocalThingy` or `_a_func`, will show in the documentation.
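
The two rules can be approximated with the standard `types` module. This is a hypothetical sketch, not the code in `tools/api_generator`; in particular, it treats an object as imported when its `__module__` differs from the module being documented, which is only one way to detect imports.

```python
# Hypothetical sketch of the inclusion rules; not the code in tools/api_generator.
import types


def public_names(module: types.ModuleType) -> list[str]:
    # Rule 1: if __all__ exists, it is the complete allow list.
    if hasattr(module, "__all__"):
        return sorted(module.__all__)
    names = []
    for name, obj in vars(module).items():
        if name.startswith("_"):
            continue  # Rule 2: names starting with "_" are ignored.
        if isinstance(obj, types.ModuleType):
            continue  # Rule 2: imported modules are ignored.
        if getattr(obj, "__module__", module.__name__) != module.__name__:
            continue  # Rule 2: objects defined elsewhere were imported.
        names.append(name)
    return sorted(names)


demo = types.ModuleType("demo")
exec(
    "from http import HTTPStatus\n"
    "class _LocalThingy: ...\n"
    "class AnotherLocalThingy: ...\n"
    "def _a_func(): ...\n"
    "def b_func(): ...\n",
    demo.__dict__,
)
print(public_names(demo))  # ['AnotherLocalThingy', 'b_func']
```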
## Tips and Tricks
1. If a package has no resources whose names do not start with `_`, and no `__all__` to
export its blocked resources (imports or names starting with `_`), the package has no content and will not be generated.
2. If you want to export something you imported with `from ___ import ____`, you _must_
use `__all__` to add the import to the public list.
3. If your code follows the Python convention that everything private starts
with `_` and everything public does not, you do not need an
`__all__` allow list.
=== PAGE: https://www.union.ai/docs/v2/byoc/community/contributing-docs/publishing ===
# Publishing
## Requirements
1. Hugo (https://gohugo.io/)
```shell
$ brew install hugo
```
2. A preferences override file with your configuration
The tool is flexible and has multiple knobs. Please review `hugo.local.toml~sample`, and configure to meet your preferences.
```shell
$ cp hugo.local.toml~sample hugo.local.toml
```
3. Make sure you review `hugo.local.toml`.
## Managing the Tutorial Pages
The tutorials are maintained in the [unionai/unionai-examples](https://github.com/unionai/unionai-examples) repository, which is imported as a git submodule in the `external`
directory.
To initialize the submodule on a fresh clone of this (`docs-builder`) repo, run:
```
$ make init-examples
```
To update the submodule to the latest `main` branch, run:
```
$ make update-examples
```
## Building and running locally
```
$ make dev
```
## Developer Experience
`make dev` launches the site in development mode.
Changes are hot-reloaded: edit a file in your favorite editor and the browser refreshes immediately.
### Controlling Development Environment
You can change how the development environment works by setting values in `hugo.local.toml`. The following settings are available:
* `variant` - The current variant to display. Change this in `hugo.local.toml`, save, and the browser will refresh automatically
with the new variant.
* `show_inactive` - If `true`, shows all the content that did not match the variant.
This is useful when a page contains multiple sections that vary with the selected variant,
so you can see them all at once.
* `highlight_active` - If `true`, also highlights the *current* content for the variant.
* `highlight_keys` - If `true`, highlights replacement keys and their values.
### Changing 'variants'
Variants are flavors of the site (that you can change at the top).
During development, you can render any variant by setting it in `hugo.local.toml`:
```
variant = "byoc"
```
We call this the "active" variant.
You can also render content from other variants at the same time, and highlight the content of your active variant.
To show the content from variants other than the currently active one, set:
```
show_inactive = true
```
To highlight the content of the currently active variant (to distinguish it from common content that applies to all variants), set:
```
highlight_active = true
```
> You can create your own copy of `hugo.local.toml` by copying from `hugo.local.toml~sample` to get started.
## Troubleshooting
### Identifying Problems: Missing Content
Content may be hidden due to `{{* variant */>}}` blocks. To see what's missing,
you can adjust the variant show/hide in development mode.
For a production-like look, set:
```
show_inactive = false
highlight_active = false
```
For a full developer experience, set:
```
show_inactive = true
highlight_active = true
```
### Identifying Problems: Page Visibility
The development site shows in red any pages missing from the variant.
For a page to exist in a variant (or be excluded from it; you have to pick one), the variant must be listed in the `variants:` field at the top of the file.
Clicking on the red page will give you the path you must add to the appropriate variant in the YAML file and a link with guidance.
Please refer to **Contributing docs and examples > Authoring** for more details.
## Building Production
```
$ make dist
```
This will build all the variants and place the result in the `dist` folder.
### Testing Production Build
You can run a local web server and serve the `dist/` folder. The site must behave correctly, as it would at its official URL.
To start a server:
```
$ make serve [PORT=]
```
If `PORT` is not specified, it defaults to 9000.
Example:
```
$ make serve PORT=4444
```
Then open your browser at `http://localhost:` to see the content. In the example above, it would be `http://localhost:4444/`.
=== PAGE: https://www.union.ai/docs/v2/byoc/release-notes ===
# Release Notes
## November 2025
### :fast_forward: Grouped Runs
We redesigned the Runs page to better support large numbers of runs. Historically, large projects produced so many runs that flat listings became difficult to navigate. The new design groups runs by their root task, leveraging the fact that while there may be millions of runs, there are typically only dozens or hundreds of deployed tasks. This grouped view, combined with enhanced filtering (by status, owner, duration, and more coming soon), makes it dramatically faster and easier to locate the exact runs users are looking for, even in the largest deployments.

### :globe_with_meridians: Apps (beta)
You can now deploy apps in Union 2.0. Apps let you host ML models, Streamlit dashboards, FastAPI services, and other interactive applications alongside your workflows. Simply define your app, deploy it, and Union will handle the infrastructure, routing, and lifecycle management. You can even call apps from your tasks to build end-to-end workflows that combine batch processing with real-time serving.
To create an app, import `flyte` and use either `FastAPIAppEnvironment` for FastAPI applications or the generic `AppEnvironment` for other frameworks. Here's a simple FastAPI example:
```python
from fastapi import FastAPI
import flyte
from flyte.app.extras import FastAPIAppEnvironment

app = FastAPI()

env = FastAPIAppEnvironment(
    name="my-api",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12))
        .with_pip_packages("fastapi", "uvicorn"),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
)


@env.app.get("/greeting/{name}")
async def greeting(name: str) -> str:
    return f"Hello, {name}!"


if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env)  # Deploy and serve your app
```
For Streamlit apps, use the generic `AppEnvironment` with a command:
```python
import flyte
import flyte.app

app_env = flyte.app.AppEnvironment(
    name="streamlit-hello-v2",
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("streamlit==1.41.1"),
    command="streamlit hello --server.port 8080",
    resources=flyte.Resources(cpu="1", memory="1Gi"),
)
```
You can call apps from tasks by using `depends_on` and making HTTP requests to the app's endpoint. Please refer to the example in the [SDK repo](https://github.com/flyteorg/flyte-sdk/blob/main/examples/apps/call_apps_in_tasks/app.py). Similarly, you can call apps from other apps (see this [example](https://github.com/flyteorg/flyte-sdk/blob/main/examples/apps/app_calling_app/app.py)).
### :label: Custom context
You can now pass configuration and metadata implicitly through your entire task execution hierarchy using custom context. This is ideal for cross-cutting concerns like tracing IDs, experiment metadata, environment information, or logging correlation keys: data that needs to be available everywhere but isn't logically part of your task's computation.
Custom context is a string key-value map that automatically flows from parent to child tasks without adding parameters to every function signature. Set it once at the run level with `with_runcontext()`, or override values within tasks using the `flyte.custom_context()` context manager:
```python
import flyte

env = flyte.TaskEnvironment("custom-context-example")


@env.task
async def leaf_task() -> str:
    # Reads run-level context
    print("leaf sees:", flyte.ctx().custom_context)
    return flyte.ctx().custom_context.get("trace_id")


@env.task
async def root() -> str:
    return await leaf_task()


if __name__ == "__main__":
    flyte.init_from_config()
    # Base context for the entire run
    run = flyte.with_runcontext(custom_context={"trace_id": "root-abc", "experiment": "v1"}).run(root)
    print(run.url)
```
### :lock: Secrets UI
Now you can view and create secrets directly from the UI. Secrets are stored securely in your configured secrets manager and injected into your task environments at runtime.

### Image Builds now run in the same project-domain
The image build task is now executed within the same project and domain as the user task, rather than in system-production. This change improves isolation and is a key step toward supporting multi-dataplane clusters.
### Support for secret mounts in Poetry and UV projects
We added support for mounting secrets into both Poetry and UV-based projects. This enables secure access to private dependencies or credentials during image build.
```python
import pathlib

import flyte

env = flyte.TaskEnvironment(
    name="uv_project_lib",
    resources=flyte.Resources(memory="1000Mi"),
    image=(
        flyte.Image.from_debian_base().with_uv_project(
            pyproject_file=pathlib.Path(__file__).parent / "pyproject.toml",
            pre=True,
            secret_mounts="my_secret",
        )
    ),
)
```
## October 2025
### :infinity: Larger fanouts
You can now run up to 50,000 actions within a run and up to 1,000 actions concurrently.
To enable observability across so many actions, we added group and sub-actions UI views, which show summary statistics about the actions which were spawned within a group or action.
You can use these summary views (as well as the action status filter) to spot check long-running or failed actions.

### :computer: Remote debugging for Ray head nodes
Rather than locally reproducing errors, sometimes you just want to zoom into the remote execution and see what's happening.
We directly enable this with the debug button.
When you click "Debug action" from an action in a run, we spin up that action's environment, code, and input data, and attach a live VS Code debugger.
Previously, this was only possible with vanilla Python tasks.
Now, you can debug multi-node distributed computations on Ray directly.

### :zap: Triggers and audit history
Triggers (see **Configure tasks > Triggers**) let you templatize and set schedules for your workflows, similar to Launch Plans in Flyte 1.0.
```python
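from datetime import datetime
import flyte

env = flyte.TaskEnvironment(name="triggers-example")  # example environment name

@env.task(triggers=flyte.Trigger.hourly())  # Every hour
def example_task(trigger_time: datetime, x: int = 1) -> str:
    return f"Task executed at {trigger_time.isoformat()} with x={x}"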
```
Once you deploy, it's possible to see all the triggers which are associated with a task:

We also maintain an audit history of every deploy, activation, and deactivation event, so you can get a sense of who's touched an automation.

### :arrow_up: Deployed tasks and input passing
You can see the runs, task spec, and triggers associated with any deployed task, and launch it from the UI. We've converted the launch forms to a convenient JSON Schema syntax, so you can easily copy-paste the inputs from a previous run into a new run for any task.
