GNN
GNN(
    *,
    exp_database: str,
    exp_schema: str,
    database: Optional[str] = None,
    schema: Optional[str] = None,
    graph: Optional[Graph] = None,
    property_transformer: Optional[PropertyTransformer] = None,
    train: Optional[b.Relationship | b.Fragment | b.Chain] = None,
    validation: Optional[b.Relationship | b.Fragment | b.Chain] = None,
    source_concept: Optional[b.Concept] = None,
    target_concept: Optional[b.Concept] = None,
    task_type: Optional[str] = None,
    eval_metric: Optional[str] = None,
    use_current_time: bool = True,
    has_time_column: Optional[bool] = None,
    model_database: Optional[str] = None,
    model_schema: Optional[str] = None,
    model_name: Optional[str] = None,
    version_name: Optional[str] = None,
    model_run_id: Optional[str] = None,
    test_batch_size: Optional[int] = None,
    stream_logs: bool = True,
    extract_embeddings: bool = False,
    dataset_alias: Optional[str] = None,
    parallel_reasoners_init: bool = True,
    **train_params: Any
)
Train, load, register, and predict with a Graph Neural Network.
GNN supports two workflows:
- Fit workflow: provide graph, train, validation, and a task_type, then call GNN.fit followed by GNN.predictions.
- Load workflow: provide a previously trained model via model_run_id or the four registry parameters (model_database, model_schema, model_name, version_name), then call GNN.load followed by GNN.predictions.
In either workflow, you can register the trained model in the Snowflake Model Registry using the GNN.register_model method.
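The rules for choosing a workflow can be summarized in plain Python. The sketch below is hypothetical: detect_workflow is not part of relationalai, and it only mirrors the requirements listed above (fit needs graph, train, validation, and task_type; load needs model_run_id or all four registry parameters). Treating a mix of both argument sets as an error is an assumption; the real constructor may resolve such cases differently.

```python
from typing import Any

# Hypothetical helper: mirrors the documented workflow rules, not the library's code.
FIT_ARGS = ("graph", "train", "validation", "task_type")
REGISTRY_ARGS = ("model_database", "model_schema", "model_name", "version_name")


def detect_workflow(**kwargs: Any) -> str:
    """Return "fit" or "load" based on which constructor arguments are set."""
    provided = {k for k, v in kwargs.items() if v is not None}

    registry_given = provided & set(REGISTRY_ARGS)
    # The registry parameters only identify a model when all four are present.
    if registry_given and registry_given != set(REGISTRY_ARGS):
        missing = sorted(set(REGISTRY_ARGS) - registry_given)
        raise ValueError(f"incomplete registry parameters, missing: {missing}")

    wants_load = "model_run_id" in provided or bool(registry_given)
    wants_fit = all(arg in provided for arg in FIT_ARGS)

    if wants_load and wants_fit:
        # Assumption: supplying both sets of arguments is treated as ambiguous.
        raise ValueError("arguments for both the fit and load workflows were given")
    if wants_load:
        return "load"
    if wants_fit:
        return "fit"
    raise ValueError("provide either the fit arguments or a model to load")


print(detect_workflow(graph="g", train="t", validation="v", task_type="regression"))  # fit
print(detect_workflow(model_run_id="run-123"))  # load
```

Note that optional load-workflow arguments such as source_concept or task_type do not affect the result here, since only the full fit set (graph, train, validation, task_type) selects the fit workflow.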
Parameters
exp_database (str) - Snowflake database for experiment storage.
exp_schema (str) - Snowflake schema for experiment storage.
database (str, default: None) - Snowflake database to save predictions in.
schema (str, default: None) - Snowflake schema to save predictions in.
graph (Graph, default: None) - The knowledge graph with edges defined. Required for the fit workflow.
property_transformer (PropertyTransformer, default: None) - Column-level semantic type annotations. If omitted, all column types are auto-inferred.
train (Relationship, Fragment, or Chain, default: None) - Training split relationship. Required for the fit workflow.
validation (Relationship, Fragment, or Chain, default: None) - Validation split relationship. Required for the fit workflow.
source_concept (Concept, default: None) - Source concept for the load workflow (inferred from train in the fit workflow).
target_concept (Concept, default: None) - Target concept for link-prediction tasks in the load workflow (inferred from train in the fit workflow).
task_type (str, default: None) - One of "binary_classification", "multiclass_classification", "multilabel_classification", "regression", "link_prediction", or "repeated_link_prediction". Required for the fit workflow. In the load workflow, inferred automatically from the registered model if not provided.
eval_metric (str, default: None) - Evaluation metric compatible with the chosen task_type (e.g. "roc_auc", "accuracy", "rmse", "link_prediction_precision@5").
use_current_time (bool, default: True) - Use the current timestamp as the prediction time.
has_time_column (bool, default: None) - Set to True when the task relationships use the at keyword for temporal ordering. In the load workflow, inferred automatically from the registered model if not provided.
model_database (str, default: None) - Snowflake database of a registered model (load workflow).
model_schema (str, default: None) - Snowflake schema of a registered model (load workflow).
model_name (str, default: None) - Name of the registered model (load workflow).
version_name (str, default: None) - Version of the registered model (load workflow).
model_run_id (str, default: None) - Run ID of a previously trained model (load workflow).
test_batch_size (int, default: None) - Batch size used during inference.
stream_logs (bool, default: True) - Stream training logs to stdout.
extract_embeddings (bool, default: False) - Extract node embeddings during prediction.
dataset_alias (str, default: None) - User-chosen alias for the dataset.
parallel_reasoners_init (bool, default: True) - Initialize the Predictive and Logic reasoners in parallel.
**train_params (Any, default: {}) - Additional hyperparameters forwarded to the GNN trainer (e.g. n_epochs, lr, train_batch_size, device, head_layers). Ignored in the load workflow.
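The keyword-only * and the trailing **train_params in the signature mean that every argument must be passed by name, and any keyword the constructor does not name explicitly is collected into a dictionary for the trainer. That is why device="cuda" and n_epochs=5 can be passed directly to GNN. The function below is a hypothetical stand-in with a reduced signature that illustrates this standard Python behavior; it does not call relationalai, and using model_run_id as the marker for the load workflow is a simplification.

```python
from typing import Any, Dict, Optional


def gnn_like(
    *,
    exp_database: str,
    exp_schema: str,
    task_type: Optional[str] = None,
    model_run_id: Optional[str] = None,
    **train_params: Any,
) -> Dict[str, Any]:
    """Hypothetical stand-in for the GNN constructor (reduced signature)."""
    # Documented behavior: train_params are ignored in the load workflow.
    # Here a given model_run_id stands in for "load workflow" (simplification).
    if model_run_id is not None:
        return {}
    # Named parameters such as task_type are never captured in train_params.
    return train_params


print(gnn_like(exp_database="EXPERIMENTS_DB", exp_schema="EXPERIMENTS_SCHEMA",
               task_type="binary_classification", device="cuda", n_epochs=5))
# {'device': 'cuda', 'n_epochs': 5}
```

Because of the bare *, a positional call such as gnn_like("DB", "SCHEMA") raises TypeError.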
Examples
Assuming the setup from the module-level Quick Start (relationalai.semantics.reasoners.predictive), fit a model and predict:
gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="EXPERIMENTS_SCHEMA",
    graph=gnn_graph,
    property_transformer=property_transformer,
    source_concept=Students,
    train=Train,
    validation=Validation,
    task_type="binary_classification",
    eval_metric="roc_auc",
    # Extra keywords are forwarded to the trainer via **train_params.
    device="cuda",
    n_epochs=5,
)
gnn.fit()
Students.predictions = gnn.predictions(domain=Test)

Load a registered model and predict:
gnn = GNN(
    exp_database="EXPERIMENTS_DB",
    exp_schema="EXPERIMENTS_SCHEMA",
    # The four registry parameters identify the model to load.
    model_database="MODELS_DB",
    model_schema="MODELS_SCHEMA",
    model_name="MY_MODEL",
    version_name="V1",
    source_concept=Students,
)
gnn.load()
Students.predictions = gnn.predictions(domain=Test)

Methods
.fit()
GNN.fit() -> GNN
Train the GNN model.
Materializes graph and task tables, creates a trainer, and submits a
training job. If training has already completed or is in progress,
calling fit() again is a no-op.
Returns:
GNN - This instance, so that calls can be chained (e.g. gnn.fit().predictions(domain=Test)).
Raises:
ValueError - If called in a load workflow.
.load()
GNN.load() -> GNN
Load a previously trained model for prediction.
Resolves the model from the Snowflake Model Registry (when
model_database, model_schema, model_name, and
version_name were provided) or by model_run_id, and
prepares the trainer for prediction. If the model is already
loaded, calling load() again is a no-op.
Returns:
GNN - This instance, so that calls can be chained (e.g. gnn.load().predictions(domain=Test)).
Raises:
ValueError - If called in a fit workflow, or if the specified model or version does not exist.
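The state rules documented for fit and load (repeat calls are no-ops, a ValueError when a method is used in the other workflow, and self returned for chaining) can be modeled with a toy class. ToyGNN is hypothetical and performs no training or model resolution; it only tracks the state transitions described above, and the jobs_submitted counter is an illustration device, not library behavior.

```python
class ToyGNN:
    """Hypothetical model of the documented fit/load contract; no real training."""

    def __init__(self, workflow: str) -> None:
        if workflow not in ("fit", "load"):
            raise ValueError("workflow must be 'fit' or 'load'")
        self.workflow = workflow
        self.ready = False       # True once trained (fit) or resolved (load)
        self.jobs_submitted = 0  # illustration only: counts "expensive" work

    def fit(self) -> "ToyGNN":
        if self.workflow != "fit":
            raise ValueError("fit() called in a load workflow")
        if not self.ready:  # repeat calls are a no-op
            self.jobs_submitted += 1
            self.ready = True
        return self  # enables gnn.fit().predictions(...)

    def load(self) -> "ToyGNN":
        if self.workflow != "load":
            raise ValueError("load() called in a fit workflow")
        if not self.ready:  # repeat calls are a no-op
            self.jobs_submitted += 1
            self.ready = True
        return self

    def predictions(self, domain: str) -> str:
        if not self.ready:
            raise ValueError("model has not been fitted or loaded")
        return f"predictions for {domain}"


gnn = ToyGNN("fit")
print(gnn.fit().fit().predictions(domain="Test"))  # predictions for Test
print(gnn.jobs_submitted)  # 1
```

Returning self from both methods is what makes the chained form in the Returns sections possible, while the ready flag is what turns a second call into a no-op.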
.predictions()
GNN.predictions(domain: b.Relationship | b.Fragment | b.Chain) -> b.Relationship
Generate predictions on a test domain.
Materializes the test table, submits a prediction job, and returns a
Relationship that can be assigned to a
concept field for downstream querying.
The prediction attributes available on the returned relationship depend on the task type:
- Classification: .probs, .predicted_labels
- Regression: .predicted_value
- Link prediction: .rank, .scores, .predicted_<target>
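The mapping from task type to prediction attributes can be written out as a lookup table. prediction_attributes is a hypothetical helper, not part of relationalai. Two details are assumptions: that "repeated_link_prediction" exposes the same attributes as "link_prediction", and that <target> is filled in with the target concept's name exactly as given.

```python
from typing import List, Optional

# Hypothetical lookup mirroring the attribute list above.
_CLASSIFICATION = ["probs", "predicted_labels"]
_ATTRIBUTES = {
    "binary_classification": _CLASSIFICATION,
    "multiclass_classification": _CLASSIFICATION,
    "multilabel_classification": _CLASSIFICATION,
    "regression": ["predicted_value"],
}
# Assumption: both link-prediction task types share the same attributes.
_LINK_TASKS = {"link_prediction", "repeated_link_prediction"}


def prediction_attributes(task_type: str, target: Optional[str] = None) -> List[str]:
    """Return the attribute names expected on the relationship from predictions()."""
    if task_type in _LINK_TASKS:
        if target is None:
            raise ValueError("link prediction needs the target concept name")
        return ["rank", "scores", f"predicted_{target}"]
    try:
        return list(_ATTRIBUTES[task_type])  # copy so callers cannot mutate the table
    except KeyError:
        raise ValueError(f"unknown task_type: {task_type!r}") from None


print(prediction_attributes("regression"))  # ['predicted_value']
print(prediction_attributes("link_prediction", target="course"))
# ['rank', 'scores', 'predicted_course']
```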
Parameters:
domain (Relationship, Fragment, or Chain) - The test split relationship (e.g. the Test relationship defined during data modeling).
Returns:
Relationship - A prediction relationship to be assigned to the source concept (e.g. User.predictions = gnn.predictions(domain=Test)).
Raises:
TypeError - If domain is not a Relationship, Fragment, or Chain.
ValueError - If the model has not been fitted or loaded, or if the test domain schema does not match the training schema.
.register_model()
GNN.register_model(
    model_database: str,
    model_schema: str,
    model_name: str,
    version_name: str,
    *,
    comment: Optional[str] = None
) -> None
Register a trained model in the Snowflake Model Registry.
After registration the model can be loaded in a later session by
passing the same model_database, model_schema,
model_name, and version_name to the GNN constructor.
Parameters:
model_database (str) - Snowflake database for the model registry entry.
model_schema (str) - Snowflake schema for the model registry entry.
model_name (str) - Name under which to register the model.
version_name (str) - Version label (e.g. "v1").
comment (str, default: None) - Free-text comment stored alongside the registry entry.
Raises:
ValueError - If the model is already registered or has not been fitted or loaded.
Examples:
# Register only after fit() or load() has completed.
gnn.fit()
gnn.register_model(
    model_database="MODELS_DB",
    model_schema="MODELS_SCHEMA",
    model_name="STUDENT_CHURN",
    version_name="V1",
    comment="First baseline model",
)

.visualize_dataset()
GNN.visualize_dataset(show_dtypes: bool = False)
Visualize the dataset graph.
Returns a graph object that can be rendered in a notebook to inspect the node and edge structure of the prepared dataset.
Parameters:
show_dtypes (bool, default: False) - Include column data types in the visualization.
Returns:
object - A graph visualization object that can be rendered in a notebook.
Raises:
ValueError - If no dataset has been prepared yet (i.e. GNN.fit has not been called).
Examples:
from IPython.display import Image, display
gnn.fit()
graph = gnn.visualize_dataset(show_dtypes=True)
display(Image(graph.create_png()))

Returned By
semantics > reasoners > predictive > estimator
└── GNN
    ├── fit
    └── load