# PyTorch

The PyTorch plugin lets you run distributed [PyTorch](https://pytorch.org/) training jobs natively on Kubernetes. It uses the [Kubeflow Training Operator](https://github.com/kubeflow/training-operator) to manage multi-node training with PyTorch's elastic launch (`torchrun`).

## When to use this plugin

- Single-node or multi-node distributed training with `DistributedDataParallel` (DDP)
- Elastic training that can scale up and down during execution
- Any workload that uses `torch.distributed` for data-parallel or model-parallel training

## Installation

```bash
pip install flyteplugins-pytorch
```

## Configuration

Create an `Elastic` configuration and pass it as `plugin_config` to a `TaskEnvironment`:

```python
import flyte
from flyteplugins.pytorch import Elastic

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nnodes=2,
        nproc_per_node=1,
    ),
    image=image,  # a flyte.Image defined elsewhere
)
```

### `Elastic` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `nnodes` | `int` or `str` | **Required.** Number of nodes. Use an int for a fixed count or a range string (e.g., `"2:4"`) for elastic training |
| `nproc_per_node` | `int` | **Required.** Number of processes (workers) per node |
| `rdzv_backend` | `str` | Rendezvous backend: `"c10d"` (default), `"etcd"`, or `"etcd-v2"` |
| `max_restarts` | `int` | Maximum worker group restarts (default: `3`) |
| `monitor_interval` | `int` | Agent health check interval in seconds (default: `3`) |
| `run_policy` | `RunPolicy` | Job run policy (cleanup, TTL, deadlines, retries) |
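For elastic training, pass `nnodes` as a `"min:max"` range string so the job can scale between node counts. A minimal sketch (the environment name is a placeholder; the image follows the pattern used in the example below):

```python
import flyte
from flyteplugins.pytorch import Elastic

# Elastic range: the job can run with anywhere from 2 to 4 nodes.
elastic_env = flyte.TaskEnvironment(
    name="torch_elastic_env",  # placeholder name
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nnodes="2:4",      # min:max node range for elastic scaling
        nproc_per_node=1,
        max_restarts=3,    # allow up to 3 worker-group restarts
    ),
    image=flyte.Image.from_debian_base(name="torch").with_pip_packages(
        "flyteplugins-pytorch", pre=True
    ),
)
```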

### `RunPolicy` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `clean_pod_policy` | `str` | Pod cleanup policy: `"None"`, `"All"`, or `"Running"` |
| `ttl_seconds_after_finished` | `int` | Seconds to keep pods after job completion |
| `active_deadline_seconds` | `int` | Maximum time the job can run (seconds) |
| `backoff_limit` | `int` | Number of retries before marking the job as failed |
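Construct a `RunPolicy` and pass it to `Elastic` via `run_policy`. A hedged sketch, assuming `RunPolicy` is exported alongside `Elastic` (the timeout values are illustrative):

```python
from flyteplugins.pytorch import Elastic, RunPolicy

policy = RunPolicy(
    clean_pod_policy="None",          # keep pods around after completion
    ttl_seconds_after_finished=3600,  # garbage-collect job resources after 1 hour
    active_deadline_seconds=7200,     # fail the job if it runs longer than 2 hours
    backoff_limit=2,                  # retry the job up to 2 times before failing
)

config = Elastic(nnodes=2, nproc_per_node=1, run_policy=policy)
```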

### NCCL tuning parameters

The plugin includes built-in NCCL timeout tuning to reduce failure-detection latency (PyTorch's default NCCL timeout is 1800 seconds):

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nccl_heartbeat_timeout_sec` | `int` | `300` | NCCL heartbeat timeout (seconds) |
| `nccl_async_error_handling` | `bool` | `False` | Enable async NCCL error handling |
| `nccl_collective_timeout_sec` | `int` | `None` | Timeout for NCCL collective operations |
| `nccl_enable_monitoring` | `bool` | `True` | Enable NCCL monitoring |
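These parameters are set directly on `Elastic`. A sketch for a GPU job that tightens failure detection (the specific timeout values are illustrative, not recommendations):

```python
from flyteplugins.pytorch import Elastic

config = Elastic(
    nnodes=2,
    nproc_per_node=8,
    nccl_heartbeat_timeout_sec=120,   # detect hung ranks faster than the 300 s default
    nccl_async_error_handling=True,   # surface NCCL errors asynchronously
    nccl_collective_timeout_sec=600,  # abort collectives stuck longer than 10 minutes
)
```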

### Writing a distributed training task

Tasks using this plugin do not need to be `async`. Initialize the process group and use `DistributedDataParallel` as you normally would with `torchrun`:

```python
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP

@torch_env.task
def train(epochs: int) -> float:
    torch.distributed.init_process_group("gloo")
    model = DDP(MyModel())
    # ... training loop ...
    return final_loss
```

> [!NOTE]
> When `nnodes=1`, the task runs as a regular Python task (no Kubernetes training job is created). Set `nnodes >= 2` for multi-node distributed training.
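The snippet above uses the `gloo` backend, which works on CPU. On GPU clusters you would typically use `nccl` and pin each worker to its local device. A hedged sketch, assuming `torch_env` requests GPUs and `MyModel` and `final_loss` come from your own training code:

```python
import os

import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP

@torch_env.task
def train_gpu(epochs: int) -> float:
    # torchrun sets LOCAL_RANK for each worker process on a node.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group("nccl")

    model = DDP(MyModel().cuda(local_rank), device_ids=[local_rank])
    # ... training loop producing final_loss ...
    torch.distributed.destroy_process_group()
    return final_loss
```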

## Example

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-pytorch",
#    "torch"
# ]
# main = "torch_distributed_train"
# params = "3"
# ///

import typing

import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from flyteplugins.pytorch.task import Elastic
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

import flyte

image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch", pre=True)

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nproc_per_node=1,
        # set nnodes=1 for local testing (runs as a regular task)
        nnodes=2,
    ),
    image=image,
)

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataLoader:
    """
    Prepare a DataLoader with a DistributedSampler so each rank
    gets a shard of the dataset.
    """
    # Dummy dataset
    x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])
    dataset = TensorDataset(x_train, y_train)

    # Distributed-aware sampler
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

def train_loop(epochs: int = 3) -> float:
    """
    A simple training loop for linear regression.
    """
    torch.distributed.init_process_group("gloo")
    model = DDP(LinearRegressionModel())

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    dataloader = prepare_dataloader(
        rank=rank,
        world_size=world_size,
        batch_size=64,
    )

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    final_loss = 0.0

    for _ in range(epochs):
        for x, y in dataloader:
            outputs = model(x)
            loss = criterion(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            final_loss = loss.item()
        if rank == 0:
            print(f"Loss: {final_loss}")

    return final_loss

@torch_env.task
def torch_distributed_train(epochs: int) -> typing.Optional[float]:
    """
    A nested task that sets up a simple distributed training job using PyTorch's
    """
    print("starting launcher")
    loss = train_loop(epochs=epochs)
    print("Training complete")
    return loss

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(torch_distributed_train, epochs=3)
    print(r.name)
    print(r.url)
    r.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/pytorch/pytorch_example.py*

## API reference

See the [PyTorch API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/pytorch/_index) for full details.

---
**Source**: https://github.com/unionai/unionai-docs/blob/main/content/integrations/pytorch/_index.md
**HTML**: https://www.union.ai/docs/v2/union/integrations/pytorch/
