loading
download_model
```python
download_model(
    to: Path,
    *,
    version: Literal["v2"],
    which: Literal["classifier", "regressor"],
    model_name: str | None = None,
) -> Literal["ok"] | list[Exception]
```
Download a TabPFN model, trying all available sources.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `to` | `Path` | The directory to download the model to. | *required* |
| `version` | `Literal['v2']` | The version of the model to download. | *required* |
| `which` | `Literal['classifier', 'regressor']` | The type of model to download. | *required* |
| `model_name` | `str \| None` | Optional specific model name to download. | `None` |
Returns:

| Type | Description |
|---|---|
| `Literal['ok'] \| list[Exception]` | `"ok"` if the model was downloaded successfully, otherwise a list of the exceptions that occurred, which the caller can handle as desired. |
load_model
```python
load_model(*, path: Path, model_seed: int) -> tuple[
    PerFeatureTransformer,
    BCEWithLogitsLoss | CrossEntropyLoss | FullSupportBarDistribution,
    InferenceConfig,
]
```
Load a model from the given checkpoint path.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `Path` | Path to the checkpoint. | *required* |
| `model_seed` | `int` | The seed to use for the model. | *required* |
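The return annotation of `load_model` is a `(model, criterion, config)` tuple whose criterion type differs by task. Assuming the two classification losses correspond to classifier checkpoints and `FullSupportBarDistribution` to regressor checkpoints, a caller can recover the task type from the loaded tuple. `task_from_loaded` is a hypothetical helper, and the classes in the usage section are stand-ins so the sketch runs without `torch` or TabPFN installed.

```python
from __future__ import annotations

from typing import Any, Literal

# Criterion class names taken from load_model's return annotation.
# The task mapping is an assumption, not something this page states.
_CLASSIFICATION_CRITERIA = {"BCEWithLogitsLoss", "CrossEntropyLoss"}
_REGRESSION_CRITERIA = {"FullSupportBarDistribution"}


def task_from_loaded(
    loaded: tuple[Any, Any, Any],
) -> Literal["classification", "regression"]:
    """Infer the task from a (model, criterion, config) tuple."""
    _model, criterion, _config = loaded
    # Match on the class name so the sketch needs no torch/TabPFN imports.
    name = type(criterion).__name__
    if name in _CLASSIFICATION_CRITERIA:
        return "classification"
    if name in _REGRESSION_CRITERIA:
        return "regression"
    raise TypeError(f"Unexpected criterion type: {name}")


# Hypothetical stand-ins for the real criterion classes.
class CrossEntropyLoss: ...


class FullSupportBarDistribution: ...


print(task_from_loaded((object(), CrossEntropyLoss(), None)))
print(task_from_loaded((object(), FullSupportBarDistribution(), None)))
```

With the real libraries installed, `isinstance` checks against the imported classes would be more robust than comparing class names, since they also accept subclasses.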