Essential API reference for fal.distributed: the key methods you need to build multi-GPU applications
This page covers the essential APIs for building multi-GPU applications with fal.distributed, focusing on the methods you'll actually use in your code.
The DistributedRunner class orchestrates multiple GPU workers for distributed computation. It handles process management, inter-process communication via ZMQ, and coordination between worker processes.
timeout (int): Maximum time (in seconds) to wait for all workers to be ready. Default: 1800 (30 minutes).
**kwargs: Additional keyword arguments passed to each worker’s setup() method.
Raises:
RuntimeError: If processes are already running or fail to start.
TimeoutError: If workers don’t become ready within the timeout period.
Example:
```python
class MyApp(fal.App):
    num_gpus = 2

    async def setup(self):
        self.runner = DistributedRunner(
            worker_cls=MyWorker,
            world_size=self.num_gpus,
        )
        # Start workers and pass the model path to their setup()
        await self.runner.start(model_path="/data/models/flux")
        # Workers are now ready to process requests
```
What it does:
Spawns world_size worker processes (one per GPU)
Each worker runs its setup() method with the provided **kwargs
Waits for all workers to signal “READY”
Starts the keepalive timer if configured
Returns when all workers are initialized and ready
This method must be called before using invoke() or stream(). It’s typically called once in your app’s setup() method.
payload (dict[str, Any]): Dictionary of keyword arguments passed to each worker's __call__() method. Default: {}.
timeout (int | None): Maximum total time (in seconds) for the entire operation. Default: None (no limit).
streaming_timeout (int | None): Maximum time (in seconds) between consecutive yields. If no data is received within this period, raises TimeoutError. Default: None.
as_text_events (bool): If True, yields Server-Sent Events (SSE) formatted as bytes. If False, yields deserialized Python objects. Default: False.
Returns:
AsyncIterator[Any]: Async iterator yielding intermediate results and the final result.
Raises:
RuntimeError: If workers are not running, encounter an error, or yield no data.
TimeoutError: If the operation exceeds timeout or streaming_timeout.
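To show how the async iterator above behaves, the sketch below consumes a stream the same way you would consume `runner.stream()`. Since fal.distributed is not importable here, `fake_stream` is a self-contained stand-in for the real runner: it yields intermediate progress events first, with the final result as the last yield.

```python
import asyncio
from typing import Any, AsyncIterator


async def fake_stream(payload: dict[str, Any]) -> AsyncIterator[Any]:
    """Stand-in for runner.stream(): intermediate events, then the final result."""
    for step in (0, 5, 10):
        yield {"step": step, "progress": step / 10}
    yield {"output": f"done: {payload['prompt']}"}


async def consume(prompt: str) -> Any:
    final = None
    # Intermediate results arrive first; the final result is the last yield
    async for event in fake_stream({"prompt": prompt, "streaming": True}):
        final = event
    return final


result = asyncio.run(consume("a red fox"))
```

With the real runner you would typically forward the intermediate events to the client (for example as SSE bytes via `as_text_events=True`) instead of discarding them.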
The DistributedWorker class is the base class for your custom GPU workers. Each instance runs on a separate GPU and handles model loading, inference, or training. Create your own worker by inheriting from DistributedWorker and overriding the setup() and __call__() methods.
```python
class MyWorker(DistributedWorker):
    def setup(self, **kwargs):
        # Load model on this GPU
        self.model = load_model().to(self.device)

    def __call__(self, prompt: str, **kwargs):
        # Process request
        return self.model.generate(prompt)
```
torch.device: The PyTorch device for this worker, e.g., cuda:0, cuda:1, etc.
Example:
```python
class MyWorker(DistributedWorker):
    def setup(self, **kwargs):
        # Load model on this worker's GPU
        self.model = MyModel().to(self.device)
        print(f"Model loaded on {self.device}")
```
Called once when the worker is initialized. Use this to load models, download weights, and prepare resources.
def setup(self, **kwargs: Any) -> None
Parameters:
**kwargs: Any keyword arguments passed to runner.start().
Example:
```python
class FluxWorker(DistributedWorker):
    def setup(self, model_path: str = "/data/flux", **kwargs):
        """Initialize the Flux model on this GPU"""
        import torch
        from diffusers import FluxPipeline

        self.rank_print(f"Loading Flux on {self.device}")
        self.pipeline = FluxPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
        ).to(self.device)

        # Disable progress bar for non-main workers
        if self.rank != 0:
            self.pipeline.set_progress_bar_config(disable=True)

        self.rank_print("Model loaded successfully")
```
Heavy operations like model loading should go in setup(), not __call__(), so they only happen once per worker.
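To make the cost difference concrete, here is a self-contained sketch (plain Python, no GPU; `ExpensiveModel` and the 0.1 s sleep are illustrative stand-ins for weight loading): the expensive work is paid once in setup(), so each subsequent __call__() is cheap.

```python
import time


class ExpensiveModel:
    def __init__(self):
        time.sleep(0.1)  # stands in for downloading/loading weights

    def generate(self, prompt):
        return f"image for: {prompt}"


class GoodWorker:
    def setup(self, **kwargs):
        # Paid once per worker process
        self.model = ExpensiveModel()

    def __call__(self, prompt, **kwargs):
        # Cheap per-request path: no reloading
        return self.model.generate(prompt)


worker = GoodWorker()
worker.setup()

start = time.perf_counter()
outputs = [worker(f"prompt {i}") for i in range(3)]
elapsed = time.perf_counter() - start
```

Had the model been constructed inside __call__(), every request would pay the load time again; with one load in setup(), `elapsed` stays near zero regardless of request count.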
```python
class ParallelWorker(DistributedWorker):
    def __call__(self, prompt: str, **kwargs):
        import torch
        import torch.distributed as dist

        # Each GPU generates independently (e.g., with a different seed)
        result = self.model.generate(prompt)

        # Gather all results to rank 0
        if self.rank == 0:
            gather_list = [torch.zeros_like(result) for _ in range(self.world_size)]
        else:
            gather_list = None
        dist.gather(result, gather_list, dst=0)

        # Only rank 0 returns the combined result
        if self.rank == 0:
            return {"outputs": gather_list}
        return {}
```
All GPUs have the same model, process different batches, and sync gradients:
```python
class DDPWorker(DistributedWorker):
    def setup(self, **kwargs):
        import torch
        from torch.nn.parallel import DistributedDataParallel as DDP

        self.model = MyModel().to(self.device)
        # Wrap with DDP for gradient synchronization
        self.model = DDP(
            self.model,
            device_ids=[self.rank],
            output_device=self.rank,
        )
        self.optimizer = torch.optim.Adam(self.model.parameters())

    def __call__(self, data_path: str, **kwargs):
        import torch
        import torch.distributed as dist

        # Load data on rank 0 only
        if self.rank == 0:
            data = load_data(data_path)
        else:
            data = None

        # Broadcast to all ranks (modifies the list in place)
        obj_list = [data]
        dist.broadcast_object_list(obj_list, src=0)
        data = obj_list[0]

        # Each GPU processes a different shard
        local_batch = data[self.rank::self.world_size]

        # Training loop
        for batch in local_batch:
            loss = self.model(batch)
            loss.backward()  # DDP syncs gradients automatically
            self.optimizer.step()
            self.optimizer.zero_grad()

        # Only rank 0 saves a checkpoint
        if self.rank == 0:
            torch.save(self.model.state_dict(), "checkpoint.pt")
            return {"checkpoint": "checkpoint.pt"}
        return {}
```
Stream intermediate results during long-running operations:
```python
class StreamingWorker(DistributedWorker):
    def __call__(self, prompt: str, steps: int = 50, streaming: bool = False):
        import torch.distributed as dist

        for step in range(steps):
            result = self.model.step(prompt)

            # Stream progress every 5 steps
            if streaming and self.rank == 0 and step % 5 == 0:
                self.add_streaming_result({
                    "step": step,
                    "progress": step / steps,
                }, as_text_event=True)

            # Sync all workers
            dist.barrier()

        # Return the final result
        if self.rank == 0:
            return {"output": result}
        return {}
```