Learn how to stream real-time results during distributed inference and training
Streaming allows you to send intermediate results from your distributed workers back to the client in real-time. This is particularly useful for long-running operations like image generation, video creation, or model training where users benefit from seeing progress updates.
For a complete working example of streaming with multi-GPU inference, see the Parallel SDXL Tutorial.
In your DistributedWorker, use add_streaming_result() to send intermediate results:
from fal.distributed import DistributedWorker
import torch.distributed as dist


class StreamingWorker(DistributedWorker):
    def __call__(self, prompt: str, steps: int = 20):
        """Run `steps` model steps, streaming a progress event after each one.

        Only rank 0 emits streaming events so the client does not receive
        duplicate messages from every GPU worker.
        """
        result = None
        for step in range(steps):
            # Do some processing; keep the latest output so it can be returned.
            result = self.model.step(prompt)

            # Only rank 0 streams to avoid duplicates
            if self.rank == 0:
                self.add_streaming_result({
                    "step": step,
                    "progress": (step + 1) / steps,
                    "message": f"Processing step {step + 1}/{steps}",
                }, as_text_event=True)

        # Return the final (last step's) result. The original example returned
        # an undefined name `final_result`, which would raise NameError.
        return {"output": result}
Key points:
add_streaming_result(): Sends data to the client
as_text_event=True: Formats as Server-Sent Events (SSE)
Only rank 0 should stream to avoid duplicate messages
// POST to the streaming endpoint and consume Server-Sent Events (SSE).
const response = await fetch('https://your-app.fal.run/stream', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ prompt: "A sunset", steps: 20 })
});

const reader = response.body.getReader();
const decoder = new TextDecoder();
// Buffer partial data: a single SSE event may arrive split across chunks,
// so parsing each chunk independently would drop or corrupt events.
let buffer = '';

while (true) {
  const { done, value } = await reader.read();
  if (done) break;
  // { stream: true } keeps multi-byte UTF-8 sequences intact across chunks.
  buffer += decoder.decode(value, { stream: true });

  // Complete SSE events are delimited by a blank line ("\n\n").
  const events = buffer.split('\n\n');
  buffer = events.pop(); // keep the trailing partial event for the next read

  for (const event of events) {
    if (event.startsWith('data: ')) {
      const data = JSON.parse(event.slice(6));
      console.log(`Step ${data.step}: ${data.progress * 100}%`);
    }
  }
}
Python:
import fal_client

# Open a streaming connection to the deployed app; events are yielded
# as they are produced by the worker.
stream = fal_client.stream(
    "username/app-name",
    arguments={"prompt": "A sunset", "steps": 20},
    # path="/stream"  # Optional: defaults to "/stream", change if your
    # endpoint uses a different path
)

for event in stream:
    print(f"Step {event['step']}: {event['progress'] * 100}%")
If your endpoint uses a path other than /stream, specify it with the path parameter to match your @fal.endpoint() decorator.
Stream intermediate results from all GPUs and combine them:
class MultiGPUStreamingWorker(DistributedWorker):
    def __call__(self, prompt: str, num_steps: int = 20):
        """Generate on every GPU, gather to rank 0, stream combined previews.

        Rank 0 gathers each worker's intermediate tensor, combines them, and
        streams the combined preview; other ranks participate in the
        collectives but do not stream.
        """
        final_result = None
        for step in range(0, num_steps, 5):  # Stream every 5 steps
            # Generate intermediate result on this GPU
            intermediate = self.model.step(prompt)

            # Gather from all workers onto rank 0; non-root ranks pass None.
            if self.rank == 0:
                gather_list = [
                    torch.zeros_like(intermediate, device=self.device)
                    for _ in range(self.world_size)
                ]
            else:
                gather_list = None
            dist.gather(intermediate, gather_list, dst=0)

            # Only rank 0 streams the combined result
            if self.rank == 0:
                combined = self.combine_results(gather_list)
                # Track the latest combined output so it can be returned.
                # The original example returned an undefined `final_result`.
                final_result = combined
                self.add_streaming_result({
                    "step": step,
                    "preview": combined,
                    "num_gpus": self.world_size,
                }, as_text_event=True)

            # Synchronize before next step
            dist.barrier()

        # Non-rank-0 workers return None here; rank 0 returns the last preview.
        return {"final": final_result}
# Good: Small progress updatesself.add_streaming_result({ "step": step, "progress": 0.5,})# Avoid: Large data in every update# self.add_streaming_result({"large_array": [...]})