
Trainer.fit_predict()

Starts a training job on the provided dataset, then immediately runs inference on the dataset’s test table using the trained model. Jobs started by the fit_predict() method are of type train_inference.

You can monitor the training process in Snowflake ML Experiments under the experiment_name, which uniquely identifies the dataset used for training. This name is generated automatically when the Dataset is created and has the format dataset_name_task_type_task_name.
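
Because the experiment name is derived from the dataset, you can work it out before training starts and then look for it in Snowflake ML Experiments. The sketch below assumes the three parts are joined with single underscores, exactly as the documented pattern suggests. expected_experiment_name is a hypothetical helper, not part of relationalai_gnns, the example values are made up, and any casing or character normalization the library might apply is not modeled.

```python
def expected_experiment_name(dataset_name: str, task_type: str, task_name: str) -> str:
    """Build a name following the documented pattern dataset_name_task_type_task_name.

    Hypothetical helper for illustration only; it is not part of relationalai_gnns
    and assumes plain underscore concatenation with no further normalization.
    """
    return f"{dataset_name}_{task_type}_{task_name}"


# Hypothetical values: a dataset "sales" with a node-classification task named "churn".
name = expected_experiment_name("sales", "node_classification", "churn")
print(name)  # sales_node_classification_churn
```

Since the individual parts can themselves contain underscores (as node_classification does here), an experiment name cannot be split back into its parts reliably; building it forward from known values avoids that ambiguity.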

Parameters

| Name | Type | Description | Optional |
|---|---|---|---|
| dataset | Dataset | The dataset object used to train the model on its associated task. | No |
| output_alias | str | An alias to append to the result tables of predictions and embeddings. | No |
| output_config | OutputConfig | Configuration specifying where to save results. | No |
| materialize_results | bool | If True, the method waits for the job to finish and then creates permanent tables. This makes the method run synchronously (blocking) instead of executing in the background. Defaults to True. | Yes |
| test_batch_size | int | Batch size to use during inference. Defaults to 128. | Yes |
| extract_embeddings | bool | If True, extract node embeddings. For node tasks and repeated link prediction tasks, source node embeddings are extracted. For link prediction tasks, both source and target node embeddings are extracted. Defaults to False. | Yes |
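
The extract_embeddings rules above can be summarized as a small lookup. This is only a sketch that encodes the table's description: TaskKind and embeddings_extracted are hypothetical names, the task labels are illustrative, and nothing here calls relationalai_gnns.

```python
from enum import Enum


class TaskKind(Enum):
    # Hypothetical labels for the task families named in the extract_embeddings description.
    NODE = "node"
    LINK_PREDICTION = "link_prediction"
    REPEATED_LINK_PREDICTION = "repeated_link_prediction"


def embeddings_extracted(task: TaskKind, extract_embeddings: bool = False) -> tuple[str, ...]:
    """Return which node embeddings are documented to be extracted for a task."""
    if not extract_embeddings:
        # Default (False): no embeddings are extracted.
        return ()
    if task is TaskKind.LINK_PREDICTION:
        # Link prediction: both source and target node embeddings.
        return ("source", "target")
    # Node tasks and repeated link prediction: source node embeddings only.
    return ("source",)


for task in TaskKind:
    print(task.value, embeddings_extracted(task, extract_embeddings=True))
# node ('source',)
# link_prediction ('source', 'target')
# repeated_link_prediction ('source',)
```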

Returns

A JobMonitor instance that you can use to monitor the submitted job, check its status, and retrieve results after inference.

Running Training and Inference with materialize_results = True

Before running inference, you must specify where the output data will be saved. Results are stored in Snowflake tables, so make sure the engine has write permissions on the target database and schema. Because materialize_results defaults to True, the example below runs training and inference synchronously: the call blocks until the job finishes and the result tables are created, instead of executing in the background.

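```python
from relationalai_gnns import OutputConfig

# Write results to Snowflake tables in DATABASE_NAME.PUBLIC;
# the engine needs write permissions on this database and schema.
output_config = OutputConfig.snowflake(
    database_name="DATABASE_NAME",
    schema_name="PUBLIC",
)

# `trainer` and `dataset` are assumed to have been created earlier.
# materialize_results defaults to True, so this call blocks until
# training, inference, and result-table creation have finished.
train_inference_job = trainer.fit_predict(
    output_alias="EXP_ALIAS",
    output_config=output_config,
    dataset=dataset,
    extract_embeddings=True,
)
```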

Running Training and Inference with materialize_results = False

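To run the job in the background instead, set materialize_results to False when calling fit_predict(). The call then returns without creating the result tables; later, call the materialize_results() method on the returned JobMonitor to export the generated predictions (and embeddings, if extract_embeddings is True).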

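```python
# Submit the job without waiting; reuses output_config from the previous example.
train_inference_job_2 = trainer.fit_predict(
    output_alias="EXP_ALIAS",
    output_config=output_config,
    dataset=dataset,
    extract_embeddings=True,
    materialize_results=False,
)

# Later, export the generated predictions and embeddings to Snowflake tables.
train_inference_job_2.materialize_results()
```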
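
The two modes described above differ in when the caller waits and when the result tables are created. The following local simulation with Python's concurrent.futures sketches that control flow under stated assumptions: SketchJobMonitor and fit_predict_sketch are hypothetical stand-ins that do not call relationalai_gnns or Snowflake, the sleep stands in for training and inference, and the sketch assumes materialize_results() waits for the job to finish before writing tables (check the JobMonitor reference for the actual behavior).

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor

# Hypothetical local simulation; this is not the relationalai_gnns API.
_executor = ThreadPoolExecutor(max_workers=1)


class SketchJobMonitor:
    """Hypothetical stand-in for JobMonitor that only models blocking behavior."""

    def __init__(self, future: Future):
        self._future = future
        self.tables_created = False

    def done(self) -> bool:
        return self._future.done()

    def materialize_results(self) -> None:
        # Assumption: wait for the background job, then "create" the permanent tables.
        self._future.result()
        self.tables_created = True


def _train_and_infer() -> str:
    time.sleep(0.2)  # placeholder for training + inference work
    return "predictions"


def fit_predict_sketch(materialize_results: bool = True) -> SketchJobMonitor:
    job = SketchJobMonitor(_executor.submit(_train_and_infer))
    if materialize_results:
        # Synchronous mode: block until the job finishes and the tables exist.
        job.materialize_results()
    # Background mode skips the wait; the caller materializes later.
    return job


blocking = fit_predict_sketch(materialize_results=True)
print(blocking.done(), blocking.tables_created)  # True True

background = fit_predict_sketch(materialize_results=False)
print(background.tables_created)  # False
background.materialize_results()
print(background.done(), background.tables_created)  # True True
```

With materialize_results=True the wait happens inside the call; with False the caller chooses when to pay it, so other work can proceed while the job runs.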