# inference

Module that defines different ways to run inference with TabPFN.
## InferenceEngine (dataclass)

Bases: `ABC`
These define how TabPFN inference can be run. As there are many things that can be cached, with multiple ways to parallelize, an inference engine defines three primary things:

- **What to cache:** As we can prepare a lot of the transformer's context, there is a tradeoff in terms of how much memory to spend on caching. This memory is used when `prepare()` is called, usually in `fit()`.
- **How to use the cached data for inference:** Based on what has been prepared for the transformer context, `iter_outputs()` will use this cached information to make predictions.
- **How to control parallelism:** As inference has trivially parallel parts, we can parallelize them. However, as the GPU is typically the bottleneck in most systems, we can define where and how we would like to parallelize the inference.

Most engines define a `prepare()` method which is specific to that inference engine; these methods do not share a common interface. A minimal end-to-end sketch follows this list.
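The following is a minimal, non-authoritative sketch of that lifecycle, using `InferenceEngineCachePreprocessing` (documented below) as the concrete engine. The import path, the way `model` and the ensemble configurations are obtained, and all argument values are assumptions for illustration; see each engine's `prepare()` for its exact signature.

```python
import numpy as np
import torch

# Assumed import path for the classes documented on this page.
from tabpfn.inference import InferenceEngineCachePreprocessing

model = ...        # a loaded PerFeatureTransformer (obtained elsewhere)
configs = [...]    # Sequence[EnsembleConfig] (obtained elsewhere)

X_train = np.random.rand(100, 5)
y_train = np.random.randint(0, 2, size=100)
X_test = np.random.rand(20, 5)

# fit() time: pay the caching cost once via the engine-specific prepare().
engine = InferenceEngineCachePreprocessing.prepare(
    X_train,
    y_train,
    cat_ix=[],                      # no categorical columns in this toy data
    model=model,
    ensemble_configs=configs,
    n_workers=1,
    rng=np.random.default_rng(0),
    dtype_byte_size=4,
    force_inference_dtype=None,
    save_peak_mem="auto",
)

# predict() time: one raw output per ensemble configuration.
for output, config in engine.iter_outputs(
    X_test, device=torch.device("cpu"), autocast=False
):
    ...  # combine the per-member outputs as appropriate for the task
```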
### iter_outputs (abstractmethod)

```python
iter_outputs(
    X: ndarray, *, device: device, autocast: bool
) -> Iterator[tuple[Tensor, EnsembleConfig]]
```
Iterate over the outputs of the model, one for each ensemble configuration that was used to initialize the engine.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X` | `ndarray` | The input data to make predictions on. | *required* |
| `device` | `device` | The device to run the model on. | *required* |
| `autocast` | `bool` | Whether to use `torch.autocast` during inference. | *required* |
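A hedged sketch of consuming this iterator, reusing `engine` and `X_test` from the example near the top of this page. How the per-member tensors should be combined, and whether they are logits or probabilities, depends on the task and on the estimator built on top of the engine; collecting them is only illustrative.

```python
import torch

# Illustrative only: collect the raw output of every ensemble member.
per_member = []
for output, config in engine.iter_outputs(
    X_test, device=torch.device("cpu"), autocast=False
):
    per_member.append((config, output.detach().cpu()))
```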
## InferenceEngineCacheKV (dataclass)

Bases: `InferenceEngine`

Inference engine that caches the actual KV cache calculated from the context of the processed training data.

This is by far the most memory-intensive inference engine, as for each ensemble member we store the full KV cache of that model. For now this is held in CPU RAM (TODO(eddiebergman): verify).
### prepare (classmethod)

```python
prepare(
    X_train: ndarray,
    y_train: ndarray,
    *,
    cat_ix: list[int],
    ensemble_configs: Sequence[EnsembleConfig],
    n_workers: int,
    model: PerFeatureTransformer,
    device: device,
    rng: Generator,
    dtype_byte_size: int,
    force_inference_dtype: dtype | None,
    save_peak_mem: bool | Literal["auto"] | float | int,
    autocast: bool
) -> InferenceEngineCacheKV
```

Prepare the inference engine.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X_train` | `ndarray` | The training data. | *required* |
| `y_train` | `ndarray` | The training target. | *required* |
| `cat_ix` | `list[int]` | The categorical indices. | *required* |
| `ensemble_configs` | `Sequence[EnsembleConfig]` | The ensemble configurations to use. | *required* |
| `n_workers` | `int` | The number of workers to use. | *required* |
| `model` | `PerFeatureTransformer` | The model to use. | *required* |
| `device` | `device` | The device to run the model on. | *required* |
| `rng` | `Generator` | The random number generator. | *required* |
| `dtype_byte_size` | `int` | Size of the dtype in bytes. | *required* |
| `force_inference_dtype` | `dtype \| None` | The dtype to force inference to. | *required* |
| `save_peak_mem` | `bool \| Literal["auto"] \| float \| int` | Whether to save peak memory usage. | *required* |
| `autocast` | `bool` | Whether to use `torch.autocast` during inference. | *required* |
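A hedged sketch of preparing this engine, reusing the illustrative `model`, `configs`, `X_train`, and `y_train` from the example near the top of this page; all argument values are assumptions. Unlike the other engines' `prepare()` methods, this one also takes `device` and `autocast`, presumably because the KV cache is computed up front.

```python
import numpy as np
import torch

from tabpfn.inference import InferenceEngineCacheKV  # assumed import path

engine = InferenceEngineCacheKV.prepare(
    X_train,
    y_train,
    cat_ix=[],
    ensemble_configs=configs,
    n_workers=1,
    model=model,
    device=torch.device("cpu"),       # device used to run the model
    rng=np.random.default_rng(0),
    dtype_byte_size=2,                # e.g. 2 bytes for float16
    force_inference_dtype=torch.float16,
    save_peak_mem="auto",
    autocast=True,
)
```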
## InferenceEngineCachePreprocessing (dataclass)

Bases: `InferenceEngine`

Inference engine that caches the preprocessing for feeding as model context on predict.

This will fit the preprocessors on the training data, as well as cache the transformed training data in RAM (not GPU RAM).

This saves some time on each predict call, at the cost of increasing the amount of memory in RAM. The main functionality performed at `predict()` time is the forward pass through the model, which is currently done sequentially.
### prepare (classmethod)

```python
prepare(
    X_train: ndarray,
    y_train: ndarray,
    *,
    cat_ix: list[int],
    model: PerFeatureTransformer,
    ensemble_configs: Sequence[EnsembleConfig],
    n_workers: int,
    rng: Generator,
    dtype_byte_size: int,
    force_inference_dtype: dtype | None,
    save_peak_mem: bool | Literal["auto"] | float | int
) -> InferenceEngineCachePreprocessing
```

Prepare the inference engine.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X_train` | `ndarray` | The training data. | *required* |
| `y_train` | `ndarray` | The training target. | *required* |
| `cat_ix` | `list[int]` | The categorical indices. | *required* |
| `model` | `PerFeatureTransformer` | The model to use. | *required* |
| `ensemble_configs` | `Sequence[EnsembleConfig]` | The ensemble configurations to use. | *required* |
| `n_workers` | `int` | The number of workers to use. | *required* |
| `rng` | `Generator` | The random number generator. | *required* |
| `dtype_byte_size` | `int` | The byte size of the dtype. | *required* |
| `force_inference_dtype` | `dtype \| None` | The dtype to force inference to. | *required* |
| `save_peak_mem` | `bool \| Literal["auto"] \| float \| int` | Whether to save peak memory usage. | *required* |
Returns:

| Type | Description |
|---|---|
| `InferenceEngineCachePreprocessing` | The prepared inference engine. |
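A hedged sketch of the trade-off described above: the preprocessing cost is paid once in `prepare()` and reused across repeated predictions. Here `engine` is an `InferenceEngineCachePreprocessing` prepared as in the example near the top of this page, and `X_test_a` / `X_test_b` are illustrative test batches.

```python
import torch

for X_batch in (X_test_a, X_test_b):
    # The fitted preprocessors and the transformed training data are reused
    # from the cache; only the forward pass through the model is repeated.
    for output, config in engine.iter_outputs(
        X_batch, device=torch.device("cpu"), autocast=False
    ):
        ...
```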
## InferenceEngineOnDemand (dataclass)

Bases: `InferenceEngine`

Inference engine that does not cache anything; it computes everything as needed.

This is one of the slowest ways to run inference, as computation that could be cached is recomputed on every call. However, the memory demand is lowest, and it can be more trivially parallelized across GPUs with some work.
### prepare (classmethod)

```python
prepare(
    X_train: ndarray,
    y_train: ndarray,
    *,
    cat_ix: list[int],
    model: PerFeatureTransformer,
    ensemble_configs: Sequence[EnsembleConfig],
    rng: Generator,
    n_workers: int,
    dtype_byte_size: int,
    force_inference_dtype: dtype | None,
    save_peak_mem: bool | Literal["auto"] | float | int
) -> InferenceEngineOnDemand
```

Prepare the inference engine.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `X_train` | `ndarray` | The training data. | *required* |
| `y_train` | `ndarray` | The training target. | *required* |
| `cat_ix` | `list[int]` | The categorical indices. | *required* |
| `model` | `PerFeatureTransformer` | The model to use. | *required* |
| `ensemble_configs` | `Sequence[EnsembleConfig]` | The ensemble configurations to use. | *required* |
| `rng` | `Generator` | The random number generator. | *required* |
| `n_workers` | `int` | The number of workers to use. | *required* |
| `dtype_byte_size` | `int` | The byte size of the dtype. | *required* |
| `force_inference_dtype` | `dtype \| None` | The dtype to force inference to. | *required* |
| `save_peak_mem` | `bool \| Literal["auto"] \| float \| int` | Whether to save peak memory usage. | *required* |
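A hedged sketch of this engine, reusing the illustrative `model`, `configs`, `X_train`, `y_train`, and `X_test` from the example near the top of this page; argument values are assumptions. Preparation is cheap here, and each prediction recomputes the context from scratch, which keeps memory low.

```python
import numpy as np
import torch

from tabpfn.inference import InferenceEngineOnDemand  # assumed import path

# prepare() only stores what is needed to run inference later; no heavy work
# is cached, so memory use stays low but every predict call pays the full cost.
engine = InferenceEngineOnDemand.prepare(
    X_train,
    y_train,
    cat_ix=[],
    model=model,
    ensemble_configs=configs,
    rng=np.random.default_rng(0),
    n_workers=1,
    dtype_byte_size=4,
    force_inference_dtype=None,
    save_peak_mem="auto",
)

for output, config in engine.iter_outputs(
    X_test, device=torch.device("cpu"), autocast=False
):
    ...
```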