
Avoiding model reloads from disk

When working with ML workloads, it is important to avoid reloading the same model into memory on every invocation of a serverless function. For this, fal offers a generic, serverless-aware caching system that can be paired with the keep_alive property: if the same function is called again within the keep-alive window, the machine is still warm and the model is already in memory, so you avoid paying the model-loading overhead again.
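To make the interaction between caching and keep_alive concrete, here is a toy, purely local simulation of a single machine. It assumes the keep-alive window is measured from the most recent call and that anything cached in memory disappears once the machine shuts down. `SimulatedWorker` and its injectable `clock` are hypothetical names for illustration only; this is not how fal implements keep_alive.

```python
import time
from typing import Callable, Optional


class SimulatedWorker:
    """Hypothetical toy model of one serverless machine (not fal's implementation).

    The machine stays up for `keep_alive` seconds after its last call; when it
    shuts down, everything it held in memory (the cached model) is lost.
    """

    def __init__(
        self,
        load_model: Callable[[], object],
        keep_alive: float,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._load_model = load_model
        self._keep_alive = keep_alive
        self._clock = clock
        self._model: Optional[object] = None
        self._last_call: Optional[float] = None
        self.loads = 0  # how many times the model had to be loaded

    def invoke(self) -> object:
        now = self._clock()
        # Idle longer than keep_alive: assume the machine shut down and its cache is gone.
        if self._last_call is not None and now - self._last_call > self._keep_alive:
            self._model = None
        if self._model is None:
            self._model = self._load_model()  # cold start: pay the loading cost
            self.loads += 1
        self._last_call = now
        return self._model


# Drive the simulation with a fake clock so the timeline is explicit.
fake_time = [0.0]
worker = SimulatedWorker(lambda: object(), keep_alive=60, clock=lambda: fake_time[0])

worker.invoke()      # t=0: cold start, model loaded
fake_time[0] = 30.0
worker.invoke()      # t=30: within keep_alive, cached model reused
fake_time[0] = 200.0
worker.invoke()      # t=200: idle for 170 s > 60 s, machine gone, model reloaded
print(worker.loads)  # 2
```

Only the call that arrives after the window has expired pays the loading cost a second time. The fal version of this pattern, using `fal.cached` together with `keep_alive`, follows.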

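```python
import fal

# Can be any model from HF's model hub, see https://huggingface.co/models
TEXT_CLASSIFICATION_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"


# fal.cached keeps the return value around, so the pipeline is built once
# per machine and reused by later invocations on that same machine.
@fal.cached
def get_model(model_name: str = TEXT_CLASSIFICATION_MODEL) -> object:
    # transformers is installed in the function's virtualenv (see requirements),
    # so it is imported here rather than at module level.
    from transformers import pipeline

    return pipeline("text-classification", model=model_name)


@fal.function(
    "virtualenv",
    requirements=["transformers", "datasets", "torch"],
    machine_type="M",
    keep_alive=60,  # keep the machine warm for 60 seconds after a call
)
def classify_text(text: str) -> tuple[str, float]:
    pipe = get_model()  # loads the model only on the first call per machine
    [result] = pipe(text)  # a single string input yields a one-element list
    return result["label"], result["score"]
```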

Here, the model loading logic lives inside the cached get_model function, which is called at inference time. The first invocation on a new machine still incurs the cold-start cost, but subsequent invocations that reach the same warm machine skip model loading almost entirely. In this example, that cuts the runtime from approximately 3 seconds to approximately 0.6 seconds.
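The cold-versus-warm difference can be reproduced locally with a sketch that uses Python's `functools.cache` as an in-process stand-in for `fal.cached`. This is an assumption for illustration: `functools.cache` only memoizes within one Python process and knows nothing about machines or keep_alive, and the 0.5 second `LOAD_SECONDS` delay is a made-up placeholder for reading weights from disk, not a measured figure.

```python
import functools
import time

LOAD_SECONDS = 0.5  # hypothetical stand-in for the cost of loading weights from disk


@functools.cache  # local stand-in for fal.cached: memoized per process, keyed on arguments
def get_model(model_name: str = "demo-model") -> dict:
    # Simulate an expensive load; thanks to the cache, this body runs only
    # once per distinct argument list in this process.
    time.sleep(LOAD_SECONDS)
    return {"name": model_name}


def timed_call() -> float:
    """Return how long a single get_model() call takes, in seconds."""
    start = time.perf_counter()
    get_model()
    return time.perf_counter() - start


cold = timed_call()  # first call: pays the simulated load cost
warm = timed_call()  # later calls: return the cached object immediately
print(f"cold: {cold:.2f}s, warm: {warm:.4f}s")
```

Note that `functools.cache` keys entries on the arguments exactly as passed, so `get_model()` and `get_model("demo-model")` would be cached separately.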


2023 © Features and Labels Inc.