inference

Module that defines different ways to run inference with TabPFN.

InferenceEngine dataclass

Bases: ABC

Defines how TabPFN inference can be run.

Most concrete engines define a prepare() classmethod that is specific to that engine; these do not share a common interface. As there are many things that can be cached, and multiple ways to parallelize, an InferenceEngine defines three primary things:

  1. What to cache:

    As much of the transformer's context can be prepared in advance, there is a tradeoff in how much memory to spend on caching. This memory is used when prepare() is called, usually inside fit().

  2. Using the cached data for inference:

    Based on what has been prepared for the transformer's context, iter_outputs() uses this cached information to make predictions.

  3. Controlling parallelism:

    As inference has trivially parallel parts, they can be parallelized. However, since the GPU is typically the bottleneck, the engine defines where and how inference is parallelized.
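Concretely, point (1) is what distinguishes the engines documented below, while points (2) and (3) are shared through iter_outputs() and the n_workers / device / autocast arguments. A minimal sketch of the choice, assuming the import path tabpfn.inference from this page's module name:

# Picking an engine decides how much is cached at fit time; all engines expose
# the same iter_outputs() interface at predict time.
from tabpfn.inference import (
    InferenceEngineCacheKV,             # caches the full KV cache per ensemble member (most memory)
    InferenceEngineCachePreprocessing,  # caches fitted preprocessors and transformed training data
    InferenceEngineOnDemand,            # caches nothing; recomputes everything on each predict
)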

iter_outputs abstractmethod

iter_outputs(
    X: ndarray, *, device: device, autocast: bool
) -> Iterator[tuple[Tensor, EnsembleConfig]]

Iterate over the outputs of the model, yielding one output for each ensemble configuration that was used to initialize the engine.

Parameters:

X (ndarray): The input data to make predictions on. Required.
device (device): The device to run the model on. Required.
autocast (bool): Whether to use torch.autocast during inference. Required.
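A prepared engine can be consumed as in the minimal sketch below, where engine is assumed to be any already-prepared engine from this module, X_test a NumPy array, and the simple averaging merely stands in for the per-configuration postprocessing that the surrounding estimator actually applies:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

outputs = []
for output, config in engine.iter_outputs(X_test, device=device, autocast=True):
    # One (Tensor, EnsembleConfig) pair is yielded per ensemble configuration
    # used at prepare() time.
    outputs.append(output.float().cpu())

# Illustrative aggregation only; in practice each output is post-processed
# according to its EnsembleConfig before being combined.
mean_output = torch.stack(outputs).mean(dim=0)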

InferenceEngineCacheKV dataclass

Bases: InferenceEngine

Inference engine that caches the full key/value (KV) attention cache computed from the context of the processed training data.

This is by far the most memory-intensive inference engine, as the full KV cache is stored for each ensemble member. For now this is held in CPU RAM (still to be verified).

prepare classmethod

prepare(
    X_train: ndarray,
    y_train: ndarray,
    *,
    cat_ix: list[int],
    ensemble_configs: Sequence[EnsembleConfig],
    n_workers: int,
    model: PerFeatureTransformer,
    device: device,
    rng: Generator,
    dtype_byte_size: int,
    force_inference_dtype: dtype | None,
    save_peak_mem: bool | Literal["auto"] | float | int,
    autocast: bool
) -> InferenceEngineCacheKV

Prepare the inference engine.

Parameters:

X_train (ndarray): The training data. Required.
y_train (ndarray): The training target. Required.
cat_ix (list[int]): The categorical indices. Required.
ensemble_configs (Sequence[EnsembleConfig]): The ensemble configurations to use. Required.
n_workers (int): The number of workers to use. Required.
model (PerFeatureTransformer): The model to use. Required.
device (device): The device to run the model on. Required.
rng (Generator): The random number generator. Required.
dtype_byte_size (int): Size of the dtype in bytes. Required.
force_inference_dtype (dtype | None): The dtype to force inference to. Required.
save_peak_mem (bool | Literal["auto"] | float | int): Whether to save peak memory usage. Required.
autocast (bool): Whether to use torch.autocast during inference. Required.
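A hedged sketch of preparing this engine, typically done inside fit(). Here X_train, y_train, cat_ix, model and ensemble_configs are placeholders assumed to come from the surrounding TabPFN estimator, rng is assumed to be a NumPy Generator, and the import path and argument values are illustrative:

import numpy as np
import torch

from tabpfn.inference import InferenceEngineCacheKV  # assumed import path

# X_train, y_train, cat_ix, model and ensemble_configs: placeholders from fit().
engine = InferenceEngineCacheKV.prepare(
    X_train,
    y_train,
    cat_ix=cat_ix,
    ensemble_configs=ensemble_configs,
    n_workers=1,
    model=model,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    rng=np.random.default_rng(0),
    dtype_byte_size=4,              # e.g. 4 bytes for float32 (illustrative)
    force_inference_dtype=None,
    save_peak_mem="auto",
    autocast=True,
)
# The KV cache for every ensemble member is now materialized, so expect memory
# usage to grow with the number of ensemble configurations.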

InferenceEngineCachePreprocessing dataclass

Bases: InferenceEngine

Inference engine that caches the preprocessing of the training data, which is fed to the model as context at predict time.

This fits the preprocessors on the training data and caches the transformed training data in RAM (not GPU memory).

This saves some time on each predict call, at the cost of increased RAM usage. The main work left at predict() time is the forward pass through the model, which is currently done sequentially.

prepare classmethod

prepare(
    X_train: ndarray,
    y_train: ndarray,
    *,
    cat_ix: list[int],
    model: PerFeatureTransformer,
    ensemble_configs: Sequence[EnsembleConfig],
    n_workers: int,
    rng: Generator,
    dtype_byte_size: int,
    force_inference_dtype: dtype | None,
    save_peak_mem: bool | Literal["auto"] | float | int
) -> InferenceEngineCachePreprocessing

Prepare the inference engine.

Parameters:

X_train (ndarray): The training data. Required.
y_train (ndarray): The training target. Required.
cat_ix (list[int]): The categorical indices. Required.
model (PerFeatureTransformer): The model to use. Required.
ensemble_configs (Sequence[EnsembleConfig]): The ensemble configurations to use. Required.
n_workers (int): The number of workers to use. Required.
rng (Generator): The random number generator. Required.
dtype_byte_size (int): The byte size of the dtype. Required.
force_inference_dtype (dtype | None): The dtype to force inference to. Required.
save_peak_mem (bool | Literal["auto"] | float | int): Whether to save peak memory usage. Required.

Returns:

InferenceEngineCachePreprocessing: The prepared inference engine.
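A hedged sketch of preparing this engine, with the same placeholder data, model and ensemble configurations as above and an assumed import path:

import numpy as np

from tabpfn.inference import InferenceEngineCachePreprocessing  # assumed import path

# X_train, y_train, cat_ix, model and ensemble_configs: placeholders from fit().
engine = InferenceEngineCachePreprocessing.prepare(
    X_train,
    y_train,
    cat_ix=cat_ix,
    model=model,
    ensemble_configs=ensemble_configs,
    n_workers=1,
    rng=np.random.default_rng(0),
    dtype_byte_size=4,              # illustrative value
    force_inference_dtype=None,
    save_peak_mem="auto",
)
# The preprocessors are now fitted and the transformed training data is cached
# in RAM; predict() only has to run the forward pass via iter_outputs().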

InferenceEngineOnDemand dataclass

Bases: InferenceEngine

Inference engine that caches nothing and computes everything as needed.

This is one of the slowest ways to run inference, as computation that could be cached is recomputed on every call. However, its memory demand is the lowest, and with some work it can be parallelized across GPUs more easily.

prepare classmethod

prepare(
    X_train: ndarray,
    y_train: ndarray,
    *,
    cat_ix: list[int],
    model: PerFeatureTransformer,
    ensemble_configs: Sequence[EnsembleConfig],
    rng: Generator,
    n_workers: int,
    dtype_byte_size: int,
    force_inference_dtype: dtype | None,
    save_peak_mem: bool | Literal["auto"] | float | int
) -> InferenceEngineOnDemand

Prepare the inference engine.

Parameters:

X_train (ndarray): The training data. Required.
y_train (ndarray): The training target. Required.
cat_ix (list[int]): The categorical indices. Required.
model (PerFeatureTransformer): The model to use. Required.
ensemble_configs (Sequence[EnsembleConfig]): The ensemble configurations to use. Required.
rng (Generator): The random number generator. Required.
n_workers (int): The number of workers to use. Required.
dtype_byte_size (int): The byte size of the dtype. Required.
force_inference_dtype (dtype | None): The dtype to force inference to. Required.
save_peak_mem (bool | Literal["auto"] | float | int): Whether to save peak memory usage. Required.
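A hedged sketch of preparing this engine, under the same placeholder assumptions as the examples above. Since nothing is cached, the preprocessing and the model's forward pass are both carried out inside iter_outputs() on every call:

import numpy as np

from tabpfn.inference import InferenceEngineOnDemand  # assumed import path

# X_train, y_train, cat_ix, model and ensemble_configs: placeholders from fit().
engine = InferenceEngineOnDemand.prepare(
    X_train,
    y_train,
    cat_ix=cat_ix,
    model=model,
    ensemble_configs=ensemble_configs,
    rng=np.random.default_rng(0),
    n_workers=1,
    dtype_byte_size=4,              # illustrative value
    force_inference_dtype=None,
    save_peak_mem="auto",
)
# Minimal memory footprint; the trade-off is that preprocessing and the model
# context are recomputed on every iter_outputs() call.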