Trainer.predict()
Submits an inference job using a specified model. Jobs initiated by the predict() method are of type inference and can be monitored using the returned JobMonitor object.
Parameter Requirements:
Input Source:
Inference supports three input scenarios:
- dataset: A new dataset with the same schema used during training is provided. The tables in the provided dataset are used to construct the graph, and inference is performed on the dataset's test table. The provided dataset must therefore include a test table.
- test_table: The dataset the model was trained on is used to construct the graph, and inference is performed on the provided test_table.
- Neither dataset nor test_table provided: The dataset the model was trained on is used, and inference is performed on its test table.
Test Table Behavior:
The test_table argument is only valid when no dataset is provided. When specified, it overrides the default test table defined in the dataset the model was trained on. This is useful for running inference on a new test table source without creating a new Dataset object.
Model Selection:
A model must be specified using one of the following options:
- registered_model_key: Use a specific registered model. The naming convention is database_name.schema_name.registered_model_name.version, where database_name is the name of the database the model is saved in, schema_name is the name of the schema the model is saved in, registered_model_name is the name of the registered model, and version is the version of the registered model.
- model_run_id: Use a specific model run ID.
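As an illustration of the four-part key format described above, a small helper can split and sanity-check such a key before passing it to predict(). The helper itself is hypothetical and not part of the API; only the key format comes from this page:

```python
def parse_registered_model_key(key: str) -> dict:
    """Split a 'database_name.schema_name.registered_model_name.version' key.

    Illustrative helper only -- not part of the predict() API.
    """
    parts = key.split(".")
    if len(parts) != 4:
        raise ValueError(
            "Expected database_name.schema_name.registered_model_name.version, "
            f"got: {key!r}"
        )
    database_name, schema_name, model_name, version = parts
    return {
        "database_name": database_name,
        "schema_name": schema_name,
        "registered_model_name": model_name,
        "version": version,
    }
```

For example, `parse_registered_model_key("db.public.new_model.v3")` returns the four components as a dict, and a malformed key raises a ValueError before the job is ever submitted.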
Parameters
| Name | Type | Description | Optional |
|---|---|---|---|
| output_alias | str | An alias to append to the result tables of predictions and embeddings. This allows you to differentiate between multiple inference runs. | No |
| output_config | OutputConfig | Configuration specifying where to save results. | No |
| materialize_results | bool | If True, the method waits for the job to finish and then creates permanent tables, running synchronously (blocking) instead of in the background. Defaults to True. | Yes |
| test_batch_size | int | Test batch size to use during inference. Defaults to 128. | Yes |
| dataset | Dataset | Dataset to run inference on. | Yes |
| test_table | str | Fully qualified path to a test table. Uses the training graph with this table replacing the original test split (if one existed). Must contain the required entity columns (and the time column if one was used during training). The label column is optional. Cannot be used together with dataset. | Yes |
| model_run_id | str | Run inference using this specific model run ID. You can obtain the model run ID from your training job via train_job.model_run_id. Either model_run_id or registered_model_key must be provided. | Yes |
| registered_model_key | str | The full identifier of a registered model, in the format database_name.schema_name.registered_model_name.version. Either model_run_id or registered_model_key must be provided. | Yes |
| extract_embeddings | bool | If True, extract node embeddings. For node tasks and repeated link prediction tasks, source node embeddings are extracted. For link prediction tasks, both source and target node embeddings are extracted. Defaults to False. | Yes |
| align_dtypes | bool | Controls how column data type mismatches are handled. False: raises a ValidationError if column data types in the provided dataset or test_table do not match those of the training dataset. True: automatically aligns column dtypes in the provided dataset or test_table with the model's expected types. Defaults to False. | Yes |
Returns
An instance of a JobMonitor object, used to monitor the submitted job, check its status, and retrieve results after inference.
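The full JobMonitor interface is not documented on this page. As a sketch of how the returned monitor might be used, assuming it exposes a status() accessor returning a terminal state string (an assumption; only model_run_id, register_model(), and materialize_results() appear elsewhere here), a simple polling loop could look like:

```python
import time

def wait_for_job(job, poll_seconds: int = 30) -> str:
    """Poll a JobMonitor-like object until it reaches a terminal state.

    `status()` returning "SUCCESS" or "FAILED" is an assumed interface;
    adapt the names to the actual JobMonitor API.
    """
    while True:
        state = job.status()
        if state in ("SUCCESS", "FAILED"):
            return state
        time.sleep(poll_seconds)
```

A loop like this is only needed for fire-and-forget workflows; with the default materialize_results=True, predict() already blocks until the job finishes.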
Examples
Running Inference With Model Run ID
Before running inference, you must specify where the output data will be saved. The data is stored in Snowflake tables. Ensure that the engine has write permissions for the target database and schema.
```python
from relationalai_gnns import OutputConfig

output_config = OutputConfig.snowflake(
    database_name="DATABASE_NAME",
    schema_name="PUBLIC",
)
```

We provide several ways to run inference. In the simplest scenario, you can use a model that has just been trained. Each trained model has a unique identifier, model_run_id, provided by Snowflake experiment tracking. After training completes (you can verify this via the job’s status), you can access the model’s run ID with train_job.model_run_id, where train_job is an instance of a JobMonitor object.
Example of running inference using a trained model:
```python
inference_job = trainer.predict(
    output_alias="EXPERIMENT_1",
    output_config=output_config,
    test_batch_size=128,
    model_run_id=train_job.model_run_id,
    extract_embeddings=True,
)
```

Note: In the example above, inference runs on the original dataset the model was trained on, since no dataset or test_table argument is specified.
By default, the engine will write the predictions to a table named PREDICTIONS_{output_alias} in the schema defined in the OutputConfig. In this example the predictions will be stored in a table named:
```
DATABASE_NAME.PUBLIC.PREDICTIONS_EXPERIMENT_1
```

The output_alias (e.g., EXPERIMENT_1) helps differentiate results from multiple inference jobs. The application is not permitted to overwrite existing tables. If a table with the same alias already exists, an error is raised.
Embedding Extraction
When performing inference, you can also extract embeddings by setting extract_embeddings=True. For:
- Node classification: embeddings are returned for the target entity column (as defined in the NodeTask)
- Repeated link prediction: embeddings are returned for the source entity column (as defined in the LinkTask)
- Link prediction: embeddings are returned for both source and target entity columns (as defined in the LinkTask)
Embedding table naming convention
- Source entity embeddings: DATABASE_NAME.PUBLIC.EMBEDDINGS_SRC_COL_NAME_ALIAS
- Target entity embeddings: DATABASE_NAME.PUBLIC.EMBEDDINGS_TGT_COL_NAME_ALIAS

Where:
- SRC_COL_NAME / TGT_COL_NAME: the source/target entity column names from the task
- ALIAS: the output_alias specified as an argument in predict()
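Following this naming convention, the embedding table names can be assembled from the entity column names and the output_alias. The small helper below is purely illustrative (it is not part of the library); only the naming pattern comes from this page:

```python
def embedding_table_names(database: str, schema: str,
                          src_col: str, tgt_col: str, alias: str) -> dict:
    """Build embedding table names following the documented pattern:
    DATABASE.SCHEMA.EMBEDDINGS_<COL_NAME>_<ALIAS>.
    """
    return {
        "source": f"{database}.{schema}.EMBEDDINGS_{src_col}_{alias}",
        "target": f"{database}.{schema}.EMBEDDINGS_{tgt_col}_{alias}",
    }
```

For example, with the output config above and output_alias="EXPERIMENT_1", a source entity column named CUSTOMER_ID would map to DATABASE_NAME.PUBLIC.EMBEDDINGS_CUSTOMER_ID_EXPERIMENT_1.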
Running Inference With materialize_results = False
The default value of the materialize_results argument is True, meaning the method waits for the job to finish and then creates permanent tables. This makes the method run synchronously (blocking) instead of executing in the background. You can set materialize_results to False when calling predict() and later run the materialize_results() method to export the generated predictions (and embeddings).
```python
# Non-blocking job
inference_job = trainer.predict(
    output_alias="EXPERIMENT_1",
    output_config=output_config,
    test_batch_size=128,
    model_run_id=train_job.model_run_id,
    extract_embeddings=True,
    materialize_results=False,
)

# Run later to export the generated predictions
inference_job.materialize_results()
```

Running Inference With Registered Model
You can register a trained model in the Snowflake model registry by calling register_model() on a JobMonitor object or by using the ModelManager.
For example:
```python
train_job.register_model(
    model_name="new_model",
    version_name="v3",
    database_name="database_name",
    schema_name="schema_name",
    comment="Production-ready model",
)
```

Successful registration returns a message like:
```
✅ Successfully registered model 'database_name.schema_name.new_model.v3' for job '01bf85fd-0106-97cd-1df9-87011969c10f'
```

Notice that we also add a model version to the registered model name. Registration is only allowed if a valid model_run_id exists. Duplicate version names for a model are not allowed.
⚠️ Note: Ensure the native app has the necessary permissions to the specified database and schema.
```sql
-- grant access to resources needed for snowflake experiment tracking
GRANT USAGE ON DATABASE <DATABASE> TO APPLICATION RELATIONALAI;
GRANT USAGE ON SCHEMA <DATABASE>.<SCHEMA> TO APPLICATION RELATIONALAI;
GRANT CREATE MODEL ON SCHEMA <DATABASE>.<SCHEMA> TO APPLICATION RELATIONALAI;
```

To perform inference with a registered model, we simply need to reference it by its registered_model_key. As an example:
```python
inference_job = trainer.predict(
    output_alias="REG_MODEL",
    output_config=output_config,
    extract_embeddings=True,
    registered_model_key="database_name.schema_name.new_model.v3",
)
```

Running Inference on a New Test Table
When a test_table is provided without a dataset, inference can be run using this new test table while keeping the original dataset for graph construction. The provided test_table replaces the original test table, and its columns are validated to ensure they include all required fields.
The test table must contain the target entity column and, if the task involves temporal data, the time column. The label column is optional.
This is particularly useful when you want to run inference on a new test table without specifying a full dataset.
Example:
```python
job = trainer.predict(
    model_run_id=train_job.model_run_id,
    test_table="PROD.DATA.FEBRUARY_CUSTOMERS",
    output_alias="EXPERIMENT_1",
    output_config=output_config,
)
```

Running Inference Using a New Dataset and an Inference-Only Task
When a dataset is provided without a test_table, inference uses the new dataset to construct the graph and performs inference on its test table. This allows you to create a dataset with the same schema as the original but with different table data sources, for example, when data sources have changed. The dataset is validated to ensure that its schema matches the training dataset, and data types can optionally be aligned automatically if align_dtypes=True.
Example:
```python
# Create inference task
inference_churn_task = NodeTask(
    connector=connector,
    name="my_inference_task",
    task_data_source={"test": "DATABASE.SCHEMA.TEST"},
    target_entity_column=ForeignKey(
        column_name="Id",
        link_to="Customer.Id",
    ),
    task_type=TaskType.BINARY_CLASSIFICATION,
)

# Create updated dataset
updated_dataset = Dataset(
    connector=connector,
    dataset_name="customer_churn_q1_2026",
    tables=[customers_table, transactions_table],
    task_description=inference_churn_task,
)
```
```python
job = trainer.predict(
    model_run_id=train_job.model_run_id,
    dataset=updated_dataset,
    align_dtypes=True,  # Handle any dtype differences if desired
    output_alias="EXPERIMENT_1",
    output_config=output_config,
)
```

The dataset must have a schema that matches the training dataset, including tables, columns, keys, and relationships, and it must include a test table. If align_dtypes=False, column data types must also match the training dataset.
Validation occurs in two steps. First, schema validation checks that the dataset structure is compatible with the training dataset. Second, dtype validation ensures that column data types match, unless align_dtypes=True. When align_dtypes=True, the column data types of the provided dataset are automatically converted to match those of the original dataset. This helps prevent errors caused by the stochastic nature of automatic dtype inference when constructing GNNTable objects.
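Conceptually, align_dtypes=True behaves like casting the incoming columns to the dtypes recorded from the training dataset. The rough pandas illustration below is not the engine's actual implementation, and the column names and dtypes are made up for the example:

```python
import pandas as pd

# Dtypes as they might have been recorded from the training dataset
# (illustrative values, not real output).
training_dtypes = {"Id": "int64", "amount": "float64"}

# New data whose columns arrived with different inferred dtypes
# (here, everything came in as strings).
new_data = pd.DataFrame({"Id": ["1", "2"], "amount": ["3.5", "4.0"]})

# Conceptually what align_dtypes=True does: cast to the expected types.
# With align_dtypes=False, a mismatch like this would instead raise
# a ValidationError.
aligned = new_data.astype(training_dtypes)
```

If a column cannot be converted (e.g., non-numeric strings into an integer type), the cast fails, so alignment fixes dtype drift but not genuinely incompatible data.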
Common Validation Errors
- “Specify only one of dataset or test_table parameter. Not both.” — You provided both dataset and test_table to predict(). Choose one or the other.
- “This task can only be used for inference with trainer.predict()” — You passed an inference-only task (with only a "test" split) to trainer.fit(). Create a task with all three splits for training.
- “Test table validation failed: Missing required columns: […]” — Your custom test table is missing required columns. Ensure it has the entity column(s) and time column (if applicable).
- Schema mismatch errors — When providing a new dataset, its schema must match the training dataset. For dtype-only differences, use align_dtypes=True.