Trainer.predict()

Submits an inference job using a specified model. Jobs initiated by the predict() method are of type inference and can be monitored using the returned JobMonitor object.

Parameter Requirements:

Input Source:

Inference supports three input scenarios:

  • dataset: A new dataset with the same schema used during training is provided. The tables in the provided dataset are used to construct the graph, and inference is performed on the dataset’s test table. Therefore, the provided dataset must include a test table.
  • test_table: The dataset the model was trained on is used to construct the graph, and inference is performed on the provided test_table.
  • Neither dataset nor test_table provided: The dataset the model was trained on is used, and inference is performed on its test table.
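
The selection logic above can be sketched in a few lines of Python. `resolve_inputs` is a hypothetical helper written purely for illustration, not part of the library:

```python
def resolve_inputs(dataset=None, test_table=None):
    """Mirror predict()'s input-selection rules (illustration only)."""
    if dataset is not None and test_table is not None:
        raise ValueError(
            "Specify only one of dataset or test_table parameter. Not both."
        )
    if dataset is not None:
        # A new dataset builds the graph; inference runs on its test table.
        return ("provided dataset", "provided dataset's test table")
    if test_table is not None:
        # The training dataset builds the graph; the provided table
        # replaces the original test split.
        return ("training dataset", test_table)
    # Default: the training dataset and its original test table.
    return ("training dataset", "training dataset's test table")
```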

Test Table Behavior:

The test_table argument is only valid when no dataset is provided. When specified, it overrides the default test table defined in the dataset the model was trained on. This is useful for running inference on a new test table source without creating a new Dataset object.

Model Selection:

A model must be specified using one of the following options:

  • registered_model_key: Use a specific registered model. The key follows the convention database_name.schema_name.registered_model_name.version, where database_name is the database where the model is saved, schema_name is the schema within that database, registered_model_name is the name of the registered model, and version is the version of the registered model.
  • model_run_id: Use a specific model run ID.
| Name | Type | Description | Optional |
| --- | --- | --- | --- |
| output_alias | str | An alias to append to the result tables of predictions and embeddings. This allows you to differentiate between multiple inference runs. | No |
| output_config | OutputConfig | Configuration specifying where to save results. | No |
| materialize_results | bool | If True, the method waits for the job to finish and then creates permanent tables, running synchronously (blocking) instead of in the background. Defaults to True. | Yes |
| test_batch_size | int | Test batch size to use during inference. Defaults to 128. | Yes |
| dataset | Dataset | Dataset to run inference on. | Yes |
| test_table | str | Fully qualified path to a test table. Uses the training graph with this table replacing the original test split (if one existed). Must contain the required entity columns (and the time column if used during training). The label column is optional. Cannot be used together with dataset. | Yes |
| model_run_id | str | Run inference using this specific model run ID. You can obtain the model run ID from your training job via train_job.model_run_id. Either a model_run_id or a registered_model_key must be provided. | Yes |
| registered_model_key | str | The full identifier of a registered model in the format database_name.schema_name.registered_model_name.version. Either a model_run_id or a registered_model_key must be provided. | Yes |
| extract_embeddings | bool | If True, extract node embeddings. For node tasks and repeated link prediction tasks, source node embeddings are extracted. For link prediction tasks, both source and target node embeddings are extracted. Defaults to False. | Yes |
| align_dtypes | bool | Controls how column data type mismatches are handled. False: raises a ValidationError if column data types in the provided dataset or test_table do not match those of the training dataset. True: automatically aligns column dtypes in the provided dataset or test_table with the model's expected types. Defaults to False. | Yes |
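
As a concrete illustration of the registered_model_key format, the hypothetical helper below splits a key into its four parts. It assumes unquoted identifiers without embedded dots and is not part of the library:

```python
def parse_registered_model_key(key):
    """Split database_name.schema_name.registered_model_name.version
    into its components (illustration only; assumes unquoted identifiers)."""
    parts = key.split(".")
    if len(parts) != 4:
        raise ValueError(
            "Expected database_name.schema_name.registered_model_name.version, "
            "got: " + key
        )
    database_name, schema_name, model_name, version = parts
    return {
        "database": database_name,
        "schema": schema_name,
        "model": model_name,
        "version": version,
    }
```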

Returns: An instance of a JobMonitor object, used to monitor the submitted job, check its status, and retrieve results after inference.

Before running inference, you must specify where the output data will be saved. The data is stored in Snowflake tables. Ensure that the engine has write permissions for the target database and schema.

from relationalai_gnns import OutputConfig

output_config = OutputConfig.snowflake(
    database_name="DATABASE_NAME",
    schema_name="PUBLIC"
)

We provide several ways to run inference. In the simplest scenario, you can use a model that has just been trained. Each trained model has a unique identifier, model_run_id, provided by Snowflake experiment tracking. After training completes (you can verify this via the job's status), you can access the model's run ID with train_job.model_run_id, where train_job is an instance of a JobMonitor object.

Example of running inference using a trained model:

inference_job = trainer.predict(
    output_alias="EXPERIMENT_1",
    output_config=output_config,
    test_batch_size=128,
    model_run_id=train_job.model_run_id,
    extract_embeddings=True
)

Note: In the example above, inference is run on the original dataset that the model was trained on, since no dataset or test_table arguments are specified.

By default, the engine will write the predictions to a table named PREDICTIONS_{output_alias} in the schema defined in the OutputConfig. In this example the predictions will be stored in a table named:

DATABASE_NAME.PUBLIC.PREDICTIONS_EXPERIMENT_1

The output_alias (e.g., EXPERIMENT_1) helps differentiate results from multiple inference jobs. The application is not permitted to overwrite existing tables. If a table with the same alias already exists, an error is raised.

Embedding Extraction

When performing inference, you can also extract embeddings by setting extract_embeddings=True. For:

  • Node classification: embeddings are returned for the target entity column (as defined in the NodeTask)
  • Repeated link prediction: embeddings are returned for the source entity column (as defined in the LinkTask)
  • Link prediction: embeddings are returned for both source and target entity columns (as defined in the LinkTask)

Embedding table naming convention

  • Source entity embeddings:
DATABASE_NAME.PUBLIC.EMBEDDINGS_SRC_COL_NAME_ALIAS
  • Target entity embeddings:
DATABASE_NAME.PUBLIC.EMBEDDINGS_TGT_COL_NAME_ALIAS

Where:

  • SRC_COL_NAME / TGT_COL_NAME: the source/target entity column names from the task
  • ALIAS: the output_alias specified as an argument to predict()
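
A hypothetical helper (not part of the library) that assembles the fully qualified output table names from these conventions might look like:

```python
def output_table_names(database_name, schema_name, output_alias,
                       src_col_name=None, tgt_col_name=None):
    """Build the output table names per the documented naming
    conventions (illustration only)."""
    prefix = f"{database_name}.{schema_name}"
    names = {"predictions": f"{prefix}.PREDICTIONS_{output_alias}"}
    if src_col_name is not None:
        names["source_embeddings"] = f"{prefix}.EMBEDDINGS_{src_col_name}_{output_alias}"
    if tgt_col_name is not None:
        names["target_embeddings"] = f"{prefix}.EMBEDDINGS_{tgt_col_name}_{output_alias}"
    return names
```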

Running Inference With materialize_results = False

The default value of the materialize_results argument is True, meaning the method waits for the job to finish and then creates permanent tables. This makes the method run synchronously (blocking) instead of executing in the background. You can set materialize_results to False when calling predict() and later run the materialize_results() method to export the generated predictions (and embeddings).

# Non-blocking job
inference_job = trainer.predict(
    output_alias="EXPERIMENT_1",
    output_config=output_config,
    test_batch_size=128,
    model_run_id=train_job.model_run_id,
    extract_embeddings=True,
    materialize_results=False
)

# Run later to export the generated predictions
inference_job.materialize_results()

You can register a trained model in the Snowflake Model Registry by calling register_model() on a JobMonitor object or by using the ModelManager. For example:

train_job.register_model(
    model_name="new_model",
    version_name="v3",
    database_name="database_name",
    schema_name="schema_name",
    comment="Production-ready model",
)

Successful registration returns a message like:

✅ Successfully registered model 'database_name.schema_name.new_model.v3' for job '01bf85fd-0106-97cd-1df9-87011969c10f'

Notice that we also add a model version to the registered model name. Registration is only allowed if a valid model_run_id exists. Duplicate version names for a model are not allowed.

⚠️ Note: Ensure the native app has the necessary permissions to the specified database and schema.

-- grant access to resources needed for snowflake experiment tracking
GRANT USAGE ON DATABASE <DATABASE> TO APPLICATION RELATIONALAI;
GRANT USAGE ON SCHEMA <DATABASE>.<SCHEMA> TO APPLICATION RELATIONALAI;
GRANT CREATE MODEL ON SCHEMA <DATABASE>.<SCHEMA> TO APPLICATION RELATIONALAI;

To perform inference with a registered model, simply reference it by its registered_model_key. For example:

inference_job = trainer.predict(
    output_alias="REG_MODEL",
    output_config=output_config,
    extract_embeddings=True,
    registered_model_key="database_name.schema_name.new_model.v3"
)

Running Inference With a New Test Table

When a test_table is provided without a dataset, inference can be run using this new test table while keeping the original dataset for graph construction. The provided test_table replaces the original test table, and its columns are validated to ensure they include all required fields.

The test table must contain the target entity column and, if the task involves temporal data, the time column. The label column is optional.

This is particularly useful when you want to run inference on a new test table without specifying a full dataset.

Example:

job = trainer.predict(
    model_run_id=train_job.model_run_id,
    test_table="PROD.DATA.FEBRUARY_CUSTOMERS",
    output_alias="EXPERIMENT_1",
    output_config=output_config
)
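
The column check described above amounts to a set comparison. The following is a minimal sketch of that behavior, not the library's actual validation code:

```python
def validate_test_table_columns(table_columns, entity_columns, time_column=None):
    """Require the entity column(s) and, for temporal tasks, the time
    column; the label column stays optional (illustration only)."""
    required = set(entity_columns)
    if time_column is not None:
        required.add(time_column)
    missing = sorted(required - set(table_columns))
    if missing:
        raise ValueError(
            "Test table validation failed: Missing required columns: "
            + str(missing)
        )
```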

Running Inference Using a New Dataset and an Inference-Only Task

When a dataset is provided without a test_table, the new dataset is used to construct the graph, and inference is performed on its test table. This allows you to create a dataset with the same schema as the original but with different table data sources, for example, when data sources have changed. The dataset is validated to ensure that its schema matches the training dataset, and data types can optionally be aligned automatically if align_dtypes=True.

Example:

# Create an inference-only task
inference_churn_task = NodeTask(
    connector=connector,
    name="my_inference_task",
    task_data_source={
        "test": "DATABASE.SCHEMA.TEST"
    },
    target_entity_column=ForeignKey(
        column_name="Id",
        link_to="Customer.Id"
    ),
    task_type=TaskType.BINARY_CLASSIFICATION
)

# Create an updated dataset
updated_dataset = Dataset(
    connector=connector,
    dataset_name="customer_churn_q1_2026",
    tables=[customers_table, transactions_table],
    task_description=inference_churn_task
)

job = trainer.predict(
    model_run_id=train_job.model_run_id,
    dataset=updated_dataset,
    align_dtypes=True,  # handle any dtype differences if desired
    output_alias="EXPERIMENT_1",
    output_config=output_config
)

The dataset must have a schema that matches the training dataset, including tables, columns, keys, and relationships, and it must include a test table. If align_dtypes=False, column data types must also match the training dataset.

Validation occurs in two steps. First, schema validation checks that the dataset structure is compatible with the training dataset. Second, dtype validation ensures that column data types match, unless align_dtypes=True. When align_dtypes=True, the column data types of the provided dataset are automatically converted to match those of the original dataset. This helps prevent errors caused by the stochastic nature of automatic dtype inference when constructing GNNTable objects.
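
The two-step flow can be sketched with plain dictionaries, where the *_schema arguments map table names to column lists and the *_dtypes arguments map (table, column) pairs to dtype names. This is an illustration of the described behavior, not the library's implementation:

```python
def validate_and_align(train_schema, new_schema,
                       train_dtypes, new_dtypes, align_dtypes=False):
    """Step 1: structural schema check. Step 2: dtype check, with
    optional automatic alignment (illustration only)."""
    if new_schema != train_schema:
        raise ValueError("Schema mismatch with the training dataset")
    aligned = dict(new_dtypes)
    for key, expected in train_dtypes.items():
        if aligned.get(key) != expected:
            if align_dtypes:
                aligned[key] = expected  # convert to the model's expected dtype
            else:
                raise ValueError("ValidationError: dtype mismatch for " + str(key))
    return aligned
```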

Common Errors

  • “Specify only one of dataset or test_table parameter. Not both.” — You provided both dataset and test_table to predict(). Choose one or the other.
  • “This task can only be used for inference with trainer.predict()” — You passed an inference-only task (with only a "test" split) to trainer.fit(). Create a task with all three splits for training.
  • “Test table validation failed: Missing required columns: […]” — Your custom test table is missing required columns. Ensure it has the entity column(s) and the time column (if applicable).
  • Schema mismatch errors — When providing a new dataset, its schema must match the training dataset. For dtype-only differences, use align_dtypes=True.