Trainer.fit_predict()
Initializes a training job on the provided dataset, then immediately performs inference on the dataset’s test table using the trained model. Jobs initiated by the fit_predict() method are of type train_inference.
You can monitor the training process in Snowflake ML Experiments under the experiment_name, which uniquely identifies the dataset used for training. This name is generated automatically when the Dataset is created, in the format dataset_name_task_type_task_name.
Parameters

| Name | Type | Description | Optional |
|---|---|---|---|
| dataset | Dataset | The dataset object used to train the model on its associated task. | No |
| output_alias | str | An alias to append to the result tables of predictions and embeddings. | No |
| output_config | OutputConfig | Configuration specifying where to save results. | No |
| materialize_results | bool | If True, the method waits for the job to finish and then creates permanent tables, making the call synchronous (blocking) instead of executing in the background. Defaults to True. | Yes |
| test_batch_size | int | Batch size to use during inference. Defaults to 128. | Yes |
| extract_embeddings | bool | If True, extract node embeddings. For node tasks and repeated link prediction tasks, source node embeddings are extracted. For link prediction tasks, both source and target node embeddings are extracted. Defaults to False. | Yes |
Returns

An instance of a JobMonitor object, which can be used to monitor the submitted job, check its status, and retrieve results after inference.
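As an illustrative sketch only: apart from materialize_results(), the JobMonitor interface is not documented in this section, so the status-checking call below (get_status()) is an assumption for illustration and may differ from the actual API.

```python
# Hypothetical JobMonitor usage; get_status() is an assumed method name,
# only materialize_results() is confirmed by this documentation.
job = trainer.fit_predict(
    output_alias="EXP_ALIAS",
    output_config=output_config,
    dataset=dataset,
    materialize_results=False,
)

# Check on the submitted job while it runs in the background (hypothetical call).
print(job.get_status())

# Once the job has finished, export the results to permanent tables.
job.materialize_results()
```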
Examples

Running Training and Inference with materialize_results = True

Before running inference, you must specify where the output data will be saved. The data is stored in Snowflake tables. Ensure that the engine has write permissions for the target database and schema. By default, materialize_results is set to True, which means the job will run training and inference synchronously (blocking) rather than executing in the background.

```python
from relationalai_gnns import OutputConfig

output_config = OutputConfig.snowflake(
    database_name="DATABASE_NAME",
    schema_name="PUBLIC",
)

train_inference_job = trainer.fit_predict(
    output_alias="EXP_ALIAS",
    output_config=output_config,
    dataset=dataset,
    extract_embeddings=True,
)
```
Running Training and Inference with materialize_results = False

You can set materialize_results to False when calling fit_predict(), and later call the materialize_results() method to export the generated predictions (and embeddings).
```python
# Submit the job without blocking; results are not yet written to permanent tables.
train_inference_job_2 = trainer.fit_predict(
    output_alias="EXP_ALIAS",
    output_config=output_config,
    dataset=dataset,
    extract_embeddings=True,
    materialize_results=False,
)

# Later, export the generated predictions (and embeddings) to permanent tables.
train_inference_job_2.materialize_results()
```