Defines training parameters.

| Parameter | Type | Description | Optional |
| --- | --- | --- | --- |
| connector | SnowflakeConnector | The connector object used to send requests to the GNN engine. | No |
| experiment_config | ExperimentConfig | An instance of ExperimentConfig that defines the database and schema in which experiment metadata will be saved. | No |
| device | str | Device used for training, inference, and feature extraction. One of "cuda" or "cpu". | No |
| seed | int | Random seed for reproducibility. Default: 42. | Yes |


| Parameter | Type | Description | Optional |
| --- | --- | --- | --- |
| n_epochs | int | Number of training epochs. An epoch is one full pass over the training data. | No |
| max_iters | int | Maximum number of batch iterations per epoch. If None, all batches are processed; otherwise, at most max_iters batches are processed per epoch. Default: None. | Yes |
| train_batch_size | int | Batch size for training. Default: 128. | Yes |
| val_batch_size | int | Batch size for validation. Default: 128. | Yes |
| eval_every | int | How often (in epochs) to evaluate on the validation set. Default: 1. | Yes |
| patience | int | Number of epochs without improvement before early stopping. Default: 5. | Yes |
| lr | float | Learning rate. Default: 0.001. | Yes |
| T_max | int | Maximum number of iterations for the cosine annealing scheduler. If None, defaults to n_epochs. Default: None. | Yes |
| eta_min | float | Minimum learning rate for cosine annealing. Default: 1e-8. | Yes |

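
The lr, T_max, and eta_min parameters describe a cosine annealing schedule, and patience describes early stopping. The sketch below shows the textbook form of both, assuming the closed-form schedule documented for PyTorch's CosineAnnealingLR and assuming one validation pass per epoch (eval_every=1), so that evaluations and epochs coincide. The helper names are hypothetical and are not part of relationalai_gnns; the library's internal implementation may differ.

```python
import math


def cosine_annealing_lr(step: int, lr: float = 0.001, t_max: int = 10,
                        eta_min: float = 1e-8) -> float:
    """Hypothetical helper: learning rate at `step`, for 0 <= step <= t_max.

    Closed-form cosine annealing: starts at `lr` and decays to `eta_min`
    at step `t_max`.
    """
    return eta_min + (lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


def should_stop(val_scores: list[float], patience: int = 5,
                higher_is_better: bool = True) -> bool:
    """Hypothetical helper: True once `patience` consecutive evaluations
    fail to improve on the best score seen so far."""
    best = None
    evals_without_improvement = 0
    for score in val_scores:
        if best is None or (score > best if higher_is_better else score < best):
            best = score
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return True
    return False


# Per the table above, T_max falls back to n_epochs when it is None.
n_epochs, T_max = 10, None
t_max = T_max if T_max is not None else n_epochs
schedule = [cosine_annealing_lr(epoch, t_max=t_max) for epoch in range(t_max + 1)]
```

The learning rate starts at lr and decays smoothly to eta_min. Under these assumptions, training stops after five consecutive evaluations that do not beat the best validation score.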

| Parameter | Type | Description | Optional |
| --- | --- | --- | --- |
| label_smoothing | bool | Whether to apply label smoothing (classification only). Default: False. | Yes |
| label_smoothing_alpha | float | Smoothing parameter α ∈ (0, 1). Default: 0.1. | Yes |
| clamp_min | int | Lower bound of the model's output distribution, as a percentile (0–100). A value of 0 applies no lower cutoff; higher values exclude the lowest portion of the output distribution from predictions. Default: 0. | Yes |
| clamp_max | int | Upper bound of the model's output distribution, as a percentile (0–100). A value of 100 applies no upper cutoff; lower values exclude the highest portion of the output distribution from predictions. Default: 100. | Yes |

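
label_smoothing_alpha and clamp_min/clamp_max are easier to read with numbers. The sketch below uses the standard label-smoothing formula (1 − α)·y + α/K for a one-hot target y over K classes, together with one plausible reading of the percentile clamp: each prediction is clipped to the [clamp_min, clamp_max] percentile range of a set of model outputs. Both helpers are hypothetical. Which outputs the library computes percentiles over, and exactly how it smooths labels, are assumptions here.

```python
import math


def smooth_labels(label: int, num_classes: int, alpha: float = 0.1) -> list[float]:
    """Hypothetical helper: (1 - alpha) * one_hot(label) + alpha / num_classes."""
    return [(1 - alpha) * (1.0 if k == label else 0.0) + alpha / num_classes
            for k in range(num_classes)]


def percentile(values: list[float], q: float) -> float:
    """Linear-interpolation percentile (NumPy's default method), q in [0, 100]."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100
    lo, hi = math.floor(pos), math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def clamp_predictions(preds: list[float], clamp_min: float = 0,
                      clamp_max: float = 100) -> list[float]:
    """Hypothetical helper: clip predictions to the [clamp_min, clamp_max]
    percentile range of `preds`. With the defaults (0, 100) the bounds are
    the min and max, so nothing changes."""
    lower = percentile(preds, clamp_min)
    upper = percentile(preds, clamp_max)
    return [min(max(p, lower), upper) for p in preds]


smoothed = smooth_labels(label=1, num_classes=4)  # approx. [0.025, 0.925, 0.025, 0.025]
clamped = clamp_predictions([float(x) for x in range(11)], clamp_min=10, clamp_max=90)
# clamped == [1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 9.0]
```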

| Parameter | Type | Description | Optional |
| --- | --- | --- | --- |
| channels | int | Number of hidden channels for the GNN, encoders, and prediction heads. Default: 128. | Yes |
| gnn_layers | int | Number of GNN layers. If None, defaults to len(fanouts). Default: None. | Yes |
| fanouts | List[int] | Number of neighbors to sample per GNN layer, e.g., [128, 64]. Default: [128, 64]. | Yes |
| conv_aggregation | str | Aggregation method for convolutions. One of "mean", "max", or "sum". Default: "mean". | Yes |
| hetero_conv_aggregation | str | Aggregation across edge types in heterogeneous graphs. One of "mean", "max", or "sum". Default: "sum". | Yes |
| gnn_norm | str | Normalization for GNN layers. One of "batch_norm", "layer_norm", or "instance_norm". Default: "layer_norm". | Yes |

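
fanouts, gnn_layers, and conv_aggregation fit together as in standard layer-wise neighbor sampling. Hop i keeps at most fanouts[i] neighbors of each node in the current frontier, and each GNN layer combines neighbor messages using the chosen aggregation. The sketch below illustrates that general technique on a toy adjacency list. The helper names are hypothetical, and whether relationalai_gnns samples with or without replacement is an assumption (this sketch samples without replacement).

```python
import random


def sample_neighborhood(adj: dict[int, list[int]], seeds: list[int],
                        fanouts: list[int], rng: random.Random) -> list[list[int]]:
    """Hypothetical helper: hop i samples up to fanouts[i] neighbors
    (without replacement) for every node in the current frontier."""
    hops = [list(seeds)]
    frontier = list(seeds)
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            neighbors = adj.get(node, [])
            next_frontier.extend(rng.sample(neighbors, min(fanout, len(neighbors))))
        hops.append(next_frontier)
        frontier = next_frontier
    return hops


def aggregate(messages: list[list[float]], how: str = "mean") -> list[float]:
    """Hypothetical helper: elementwise "mean", "max", or "sum" over neighbor vectors."""
    columns = list(zip(*messages))
    if how == "sum":
        return [sum(c) for c in columns]
    if how == "mean":
        return [sum(c) / len(c) for c in columns]
    if how == "max":
        return [max(c) for c in columns]
    raise ValueError(f"unknown aggregation: {how!r}")


fanouts = [2, 1]
gnn_layers = None
num_layers = gnn_layers if gnn_layers is not None else len(fanouts)  # -> 2

adj = {0: [1, 2, 3, 4], 1: [5, 6], 2: [7], 3: [], 4: [8, 9]}
hops = sample_neighborhood(adj, seeds=[0], fanouts=fanouts, rng=random.Random(0))
```

Because the per-hop fanouts multiply, [128, 64] can reach up to 128 × 64 = 8,192 second-hop neighbors per seed node. This is why fanouts largely determine memory use.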

| Parameter | Type | Description | Optional |
| --- | --- | --- | --- |
| head_layers | int | Number of MLP layers in the prediction head. Default: 1. | Yes |
| head_norm | str | Normalization for the MLP prediction head. One of "batch_norm" or "layer_norm". Default: "batch_norm". | Yes |


| Parameter | Type | Description | Optional |
| --- | --- | --- | --- |
| use_temporal_encoder | bool | Whether to use a temporal encoding model. Default: True. | Yes |
| temporal_strategy | str | Strategy for temporal neighbor sampling. "uniform" ignores time; "last" picks the most recent neighbors. Default: "uniform". | Yes |

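
The difference between the two temporal_strategy values can be shown on a list of timestamped neighbors. This sketch assumes, as in PyTorch Geometric's NeighborLoader, that only neighbors with a timestamp at or before the seed node's time are eligible. Under that assumption, "uniform" samples among the eligible neighbors without regard to recency, and "last" keeps the most recent ones. The helper name is hypothetical, and the library's exact eligibility rule is an assumption.

```python
import random
from typing import Optional


def temporal_sample(neighbors: list[tuple[str, int]], seed_time: int, k: int,
                    strategy: str = "uniform",
                    rng: Optional[random.Random] = None) -> list[str]:
    """Hypothetical helper: pick up to k neighbors from (node_id, timestamp) pairs.

    Assumption: only neighbors with timestamp <= seed_time are eligible,
    which prevents the model from looking into the future.
    """
    eligible = [(node, ts) for node, ts in neighbors if ts <= seed_time]
    if strategy == "uniform":
        # Every eligible neighbor is equally likely, regardless of its timestamp.
        rng = rng or random.Random()
        chosen = rng.sample(eligible, min(k, len(eligible)))
    elif strategy == "last":
        # Keep the k most recent eligible neighbors.
        chosen = sorted(eligible, key=lambda pair: pair[1], reverse=True)[:k]
    else:
        raise ValueError(f"unknown temporal_strategy: {strategy!r}")
    return [node for node, _ in chosen]


events = [("a", 1), ("b", 5), ("c", 3), ("d", 9)]
recent = temporal_sample(events, seed_time=6, k=2, strategy="last")  # ['b', 'c']
```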

| Parameter | Type | Description | Optional |
| --- | --- | --- | --- |
| num_negative | int | Number of negative samples per source node (for link prediction). Default: 10. | Yes |
| negative_sampling_strategy | str | Strategy: "random" or "degree_based". "degree_based" favors popular nodes. Default: "random". | Yes |

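
The two negative_sampling_strategy options can be compared in a short sketch. It assumes negatives are drawn with replacement from candidate destination nodes that are not true positives for the source node. It also reads "degree_based" as sampling with probability proportional to node degree, which is one simple way to "favor popular nodes"; the library's exact weighting is not documented here. The helper name is hypothetical.

```python
import random
from collections import Counter
from typing import Optional


def sample_negatives(positives: set[str], candidates: list[str], degree: dict[str, int],
                     num_negative: int = 10, strategy: str = "random",
                     rng: Optional[random.Random] = None) -> list[str]:
    """Hypothetical helper: draw num_negative non-positive destinations
    (with replacement) for one source node."""
    rng = rng or random.Random()
    pool = [c for c in candidates if c not in positives]
    if strategy == "random":
        weights = None  # uniform over the pool
    elif strategy == "degree_based":
        weights = [degree[c] for c in pool]  # probability proportional to degree
    else:
        raise ValueError(f"unknown negative_sampling_strategy: {strategy!r}")
    return rng.choices(pool, weights=weights, k=num_negative)


candidates = ["a", "b", "c", "d"]
degree = {"a": 5, "b": 100, "c": 1, "d": 1}
counts = Counter(sample_negatives({"a"}, candidates, degree, num_negative=1000,
                                  strategy="degree_based", rng=random.Random(0)))
```

With degree-based sampling, the high-degree node "b" dominates the negatives. This yields harder negatives than uniform sampling, at the cost of under-sampling rare nodes.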

| Parameter | Type | Description | Optional |
| --- | --- | --- | --- |
| text_embedder | str | Text embedding model. One of "model2vec-potion-base-4M" or "bert-base-distill". Default: "model2vec-potion-base-4M". | Yes |
| id_awareness | bool | Whether to use ID-awareness embeddings. Default: False. | Yes |
| shallow_embeddings_list | List[str] | Tables to assign learnable shallow embeddings. Default: []. | Yes |

Returns an instance of the TrainerConfig class.
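from relationalai_gnns import ExperimentConfig, TrainerConfig
experiment_config = ExperimentConfig(database="database_name", schema="schema_name")  # "schema" keyword is assumed
trainer_config = TrainerConfig(
    connector=connector,  # an existing SnowflakeConnector instance
    experiment_config=experiment_config,
    device="cuda",  # or "cpu"
    n_epochs=10,  # placeholder value
)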

| Name | Description | Returns |
| --- | --- | --- |
| to_dict | Returns the TrainerConfig contents as a dictionary. | Dict |
| validate | Validates the model's configuration. | bool |

Returns the TrainerConfig contents as a dictionary.
model_config_dict = trainer_config.to_dict()
Validates the model configuration. Returns True if the configuration is valid; otherwise, returns False along with a list of validation errors.