
TrainerConfig

Defines training parameters.

| Parameter | Type | Description | Optional |
|---|---|---|---|
| connector | SnowflakeConnector | The connector object used to send requests to the GNN engine. | No |
| experiment_config | ExperimentConfig | An ExperimentConfig instance that defines the database and schema in which experiment metadata is saved. | No |
| device | str | Device on which to run training, inference, and feature extraction. One of "cuda" or "cpu". | No |
| seed | int | Random seed for reproducibility. Default: 42. | Yes |

| Parameter | Type | Description | Optional |
|---|---|---|---|
| n_epochs | int | Number of training epochs. One epoch is a full pass over the training data. | No |
| max_iters | int | Maximum number of batch iterations per epoch. If None, all batches are processed; otherwise, at most max_iters batches are processed per epoch. Default: None. | Yes |
| train_batch_size | int | Batch size for training. Default: 128. | Yes |
| val_batch_size | int | Batch size for validation. Default: 128. | Yes |
| eval_every | int | How often, in epochs, to evaluate on the validation set. Default: 1. | Yes |
| patience | int | Number of epochs without improvement before training stops early. Default: 5. | Yes |
| lr | float | Learning rate. Default: 0.001. | Yes |
| T_max | int | Maximum number of iterations for the cosine annealing scheduler. If None, defaults to n_epochs. Default: None. | Yes |
| eta_min | float | Minimum learning rate for cosine annealing. Default: 1e-8. | Yes |
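
The lr, T_max, and eta_min parameters describe a cosine annealing schedule. The sketch below uses the standard closed-form cosine annealing formula (the one PyTorch's CosineAnnealingLR is based on) and assumes the schedule is stepped once per epoch, which this page does not state. The helper name cosine_annealing_lrs is hypothetical; this is an illustration, not the engine's implementation.

```python
import math
from typing import List, Optional


def cosine_annealing_lrs(
    lr: float = 0.001,
    n_epochs: int = 10,
    T_max: Optional[int] = None,
    eta_min: float = 1e-8,
) -> List[float]:
    """Learning rate used in each epoch under cosine annealing.

    Assumption: the scheduler steps once per epoch. If T_max is None it
    falls back to n_epochs, mirroring the documented default.
    """
    if T_max is None:
        T_max = n_epochs
    lrs = []
    for t in range(n_epochs):
        # Closed form: equals lr at t = 0 and eta_min at t = T_max.
        cos_factor = (1 + math.cos(math.pi * t / T_max)) / 2
        lrs.append(eta_min + (lr - eta_min) * cos_factor)
    return lrs


if __name__ == "__main__":
    for epoch, value in enumerate(cosine_annealing_lrs(n_epochs=5)):
        print(f"epoch {epoch}: lr = {value:.6g}")
```

Under these assumptions, with the defaults and T_max left at n_epochs, the rate starts at 0.001 and decreases every epoch, approaching 1e-8 without quite reaching it in the final epoch.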

| Parameter | Type | Description | Optional |
|---|---|---|---|
| label_smoothing | bool | Whether to apply label smoothing (classification tasks only). Default: False. | Yes |
| label_smoothing_alpha | float | Smoothing parameter α ∈ (0, 1). Default: 0.1. | Yes |
| clamp_min | int | Lower bound of the model's output distribution, expressed as a percentile (0–100). A value of 0 applies no lower cutoff; higher values exclude the lowest portion of the output distribution from predictions. Default: 0. | Yes |
| clamp_max | int | Upper bound of the model's output distribution, expressed as a percentile (0–100). A value of 100 applies no upper cutoff; lower values exclude the highest portion of the output distribution from predictions. Default: 100. | Yes |
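
As a rough illustration of label_smoothing_alpha and the clamp_min / clamp_max percentiles, here is a minimal pure-Python sketch. It assumes the common label-smoothing formulation y = (1 - α) · one_hot + α / K over K classes, and it assumes the percentile bounds are computed over a reference set of values (here a hypothetical list of training labels) with linear interpolation. The page specifies neither detail, and the helper names smooth_labels, percentile, and clamp_predictions are hypothetical.

```python
from typing import List, Sequence


def smooth_labels(label: int, num_classes: int, alpha: float = 0.1) -> List[float]:
    """Smoothed one-hot target: (1 - alpha) * one_hot + alpha / num_classes."""
    if not 0 < alpha < 1:
        raise ValueError("alpha must lie in (0, 1)")
    return [
        (1 - alpha) * (1.0 if k == label else 0.0) + alpha / num_classes
        for k in range(num_classes)
    ]


def percentile(values: Sequence[float], p: float) -> float:
    """Percentile p in [0, 100], interpolating linearly between sorted values."""
    if not values:
        raise ValueError("values must not be empty")
    ordered = sorted(values)
    rank = p * (len(ordered) - 1) / 100
    lower = int(rank)
    upper = min(lower + 1, len(ordered) - 1)
    frac = rank - lower
    return ordered[lower] + (ordered[upper] - ordered[lower]) * frac


def clamp_predictions(
    preds: Sequence[float],
    reference: Sequence[float],
    clamp_min: float = 0,
    clamp_max: float = 100,
) -> List[float]:
    """Clip predictions to the [clamp_min, clamp_max] percentile range of reference."""
    low = percentile(reference, clamp_min)
    high = percentile(reference, clamp_max)
    return [min(max(p, low), high) for p in preds]


if __name__ == "__main__":
    print(smooth_labels(label=2, num_classes=4, alpha=0.1))
    train_labels = [float(v) for v in range(101)]  # hypothetical values 0..100
    print(clamp_predictions([-5.0, 50.0, 250.0], train_labels, clamp_min=5, clamp_max=95))
```

With the defaults clamp_min=0 and clamp_max=100, the bounds are the minimum and maximum of the reference values, so only predictions outside the observed range are clipped.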

| Parameter | Type | Description | Optional |
|---|---|---|---|
| channels | int | Hidden channels for the GNN, encoders, and prediction heads. Default: 128. | Yes |
| gnn_layers | int | Number of GNN layers. If None, defaults to len(fanouts). Default: None. | Yes |
| fanouts | List[int] | Number of neighbors to sample at each GNN layer. Default: [128, 64]. | Yes |
| conv_aggregation | str | Aggregation method for convolutions. One of "mean", "max", or "sum". Default: "mean". | Yes |
| hetero_conv_aggregation | str | Aggregation across edge types in heterogeneous graphs. One of "mean", "max", or "sum". Default: "sum". | Yes |
| gnn_norm | str | Normalization for GNN layers. One of "batch_norm", "layer_norm", or "instance_norm". Default: "layer_norm". | Yes |
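
The relationship between fanouts and gnn_layers can be made concrete with a little arithmetic. The sketch below assumes standard layer-wise neighbor sampling on a homogeneous graph, where every node reached at one hop draws at most fanouts[i] neighbors at the next, and it ignores duplicate nodes; the result is an upper bound on the sampled subgraph size per seed node, not a figure reported by the engine. It also shows, on scalar messages for simplicity, what the "mean", "max", and "sum" aggregation options compute. The helper names are hypothetical.

```python
from typing import List, Optional


def resolve_gnn_layers(fanouts: List[int], gnn_layers: Optional[int] = None) -> int:
    """Mirror the documented default: gnn_layers falls back to len(fanouts)."""
    return len(fanouts) if gnn_layers is None else gnn_layers


def max_nodes_per_hop(fanouts: List[int]) -> List[int]:
    """Upper bound on sampled nodes at each hop, starting from one seed node."""
    counts = [1]  # hop 0: the seed node itself
    for fanout in fanouts:
        counts.append(counts[-1] * fanout)
    return counts


def aggregate(messages: List[float], how: str = "mean") -> float:
    """Combine neighbor messages; real GNNs do this per feature dimension."""
    if how == "mean":
        return sum(messages) / len(messages)
    if how == "max":
        return max(messages)
    if how == "sum":
        return sum(messages)
    raise ValueError(f"unknown aggregation: {how!r}")


if __name__ == "__main__":
    fanouts = [128, 64]
    hops = max_nodes_per_hop(fanouts)
    print(resolve_gnn_layers(fanouts), hops, sum(hops))
    print([aggregate([1.0, 2.0, 6.0], how) for how in ("mean", "max", "sum")])
```

Under these assumptions, the default fanouts of [128, 64] give two GNN layers and at most 1 + 128 + 8,192 = 8,321 nodes per seed node, which is why larger fanouts quickly increase memory use.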

| Parameter | Type | Description | Optional |
|---|---|---|---|
| head_layers | int | Number of MLP layers in the prediction head. Default: 1. | Yes |
| head_norm | str | Normalization for the MLP prediction head. One of "batch_norm" or "layer_norm". Default: "batch_norm". | Yes |

| Parameter | Type | Description | Optional |
|---|---|---|---|
| use_temporal_encoder | bool | Whether to use a temporal encoding model. Default: True. | Yes |
| temporal_strategy | str | Strategy for temporal neighbor sampling. "uniform" samples neighbors uniformly, without regard to time; "last" picks the most recent neighbors. Default: "uniform". | Yes |
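
The two temporal_strategy values can be illustrated with a small sampler over timestamped neighbors. This is a sketch under assumptions: like temporal neighbor sampling in PyTorch Geometric, it only considers neighbors whose timestamp is at or before the seed node's timestamp, so no future information leaks; "uniform" then draws at random without regard to recency, and "last" keeps the most recent ones. The function name and the (id, timestamp) data layout are hypothetical.

```python
import random
from typing import List, Optional, Tuple

Neighbor = Tuple[str, int]  # (neighbor id, timestamp)


def sample_temporal_neighbors(
    neighbors: List[Neighbor],
    seed_time: int,
    k: int,
    strategy: str = "uniform",
    rng: Optional[random.Random] = None,
) -> List[Neighbor]:
    """Pick up to k neighbors whose timestamp is <= seed_time."""
    # Assumption: neighbors from the seed's future are never eligible.
    eligible = [n for n in neighbors if n[1] <= seed_time]
    if strategy == "uniform":
        # Every eligible neighbor is equally likely, regardless of its timestamp.
        rng = rng or random.Random()
        return rng.sample(eligible, min(k, len(eligible)))
    if strategy == "last":
        # Most recent first, keep the top k.
        return sorted(eligible, key=lambda n: n[1], reverse=True)[:k]
    raise ValueError(f"unknown strategy: {strategy!r}")


if __name__ == "__main__":
    history = [("a", 1), ("b", 5), ("c", 3), ("d", 9), ("e", 7)]
    print(sample_temporal_neighbors(history, seed_time=7, k=2, strategy="last"))
    print(sample_temporal_neighbors(history, seed_time=7, k=2, strategy="uniform",
                                    rng=random.Random(42)))
```

In this toy history, "last" with seed_time=7 and k=2 returns ("e", 7) and ("b", 5), while "uniform" may return any two of the four eligible neighbors; ("d", 9) is never returned because it lies in the seed's future.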

| Parameter | Type | Description | Optional |
|---|---|---|---|
| num_negative | int | Number of negative samples per source node (link prediction only). Default: 10. | Yes |
| negative_sampling_strategy | str | Negative sampling strategy, either "random" or "degree_based". "degree_based" favors popular nodes. Default: "random". | Yes |
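
A minimal sketch of the two negative_sampling_strategy options for link prediction. It assumes "degree_based" means drawing candidate destination nodes with probability proportional to their degree; the page only says it favors popular nodes, and other variants (such as raising degree to a power) exist. For simplicity it samples with replacement and does not filter out true positive links. The function name and arguments are hypothetical.

```python
import random
from collections import Counter
from typing import Dict, List, Optional


def sample_negatives(
    candidates: List[str],
    degree: Dict[str, int],
    num_negative: int = 10,
    strategy: str = "random",
    rng: Optional[random.Random] = None,
) -> List[str]:
    """Draw num_negative destination nodes (with replacement) for one source node."""
    rng = rng or random.Random()
    if strategy == "random":
        # Every candidate is equally likely.
        return rng.choices(candidates, k=num_negative)
    if strategy == "degree_based":
        # Assumption: probability proportional to degree, so popular nodes dominate.
        weights = [degree[c] for c in candidates]
        return rng.choices(candidates, weights=weights, k=num_negative)
    raise ValueError(f"unknown strategy: {strategy!r}")


if __name__ == "__main__":
    items = ["popular", "niche_1", "niche_2"]
    deg = {"popular": 98, "niche_1": 1, "niche_2": 1}
    draws = sample_negatives(items, deg, num_negative=1000,
                             strategy="degree_based", rng=random.Random(0))
    print(Counter(draws))
```

Degree-based negatives are typically harder for the model to rule out than uniform ones, because popular nodes are also the ones most likely to appear in real links.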

| Parameter | Type | Description | Optional |
|---|---|---|---|
| text_embedder | str | Text embedding model. One of "model2vec-potion-base-4M" or "bert-base-distill". Default: "model2vec-potion-base-4M". | Yes |
| id_awareness | bool | Whether to use ID-awareness embeddings. Default: False. | Yes |
| shallow_embeddings_list | List[str] | Tables for which learnable shallow embeddings are created. Default: []. | Yes |

Returns an instance of the TrainerConfig class.

```python
from relationalai_gnns import ExperimentConfig, TrainerConfig

# `connector` is assumed to be a SnowflakeConnector created earlier.
experiment_config = ExperimentConfig(
    database="database_name",
    schema="schema_name",
)
trainer_config = TrainerConfig(
    connector=connector,
    experiment_config=experiment_config,
    device="cuda",
    n_epochs=10,
    patience=5,
)
```

| Name | Description | Returns |
|---|---|---|
| to_dict | Returns the TrainerConfig contents as a dictionary. | Dict |
| validate | Validates the model configuration. | bool |

to_dict: Returns the contents of the TrainerConfig as a dictionary.

```python
model_config_dict = trainer_config.to_dict()
```

validate: Validates the model configuration. Returns True if the configuration is valid; otherwise, returns False along with a list of validation errors.

```python
trainer_config.validate()
```
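
To make the documented constraints easier to scan, here is a hypothetical, stand-alone checker that applies a few rules stated in the parameter tables on this page (allowed string values, the (0, 1) range for label_smoothing_alpha, and 0–100 percentile bounds) to a plain dictionary. It is not the library's validate implementation, which may check more or different rules; it only mirrors the described outcome of a boolean plus a list of error messages. The names ALLOWED and check_trainer_settings are hypothetical.

```python
from typing import Any, Dict, List, Tuple

# Allowed values as listed in the parameter tables on this page.
ALLOWED = {
    "device": {"cuda", "cpu"},
    "conv_aggregation": {"mean", "max", "sum"},
    "hetero_conv_aggregation": {"mean", "max", "sum"},
    "gnn_norm": {"batch_norm", "layer_norm", "instance_norm"},
    "head_norm": {"batch_norm", "layer_norm"},
    "temporal_strategy": {"uniform", "last"},
    "negative_sampling_strategy": {"random", "degree_based"},
    "text_embedder": {"model2vec-potion-base-4M", "bert-base-distill"},
}


def check_trainer_settings(settings: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Hypothetical checker; returns (is_valid, error_messages)."""
    errors: List[str] = []
    # String-valued options must be one of the documented choices.
    for name, allowed in ALLOWED.items():
        if name in settings and settings[name] not in allowed:
            errors.append(f"{name}={settings[name]!r} is not one of {sorted(allowed)}")
    # label_smoothing_alpha must lie strictly between 0 and 1.
    alpha = settings.get("label_smoothing_alpha", 0.1)
    if not 0 < alpha < 1:
        errors.append(f"label_smoothing_alpha={alpha} must lie in (0, 1)")
    # clamp_min and clamp_max are percentiles.
    for name, default in (("clamp_min", 0), ("clamp_max", 100)):
        value = settings.get(name, default)
        if not 0 <= value <= 100:
            errors.append(f"{name}={value} must be a percentile between 0 and 100")
    return (not errors, errors)


if __name__ == "__main__":
    print(check_trainer_settings({"device": "cuda", "n_epochs": 10, "patience": 5}))
    print(check_trainer_settings({"device": "tpu", "label_smoothing_alpha": 1.5}))
```

The first call returns (True, []); the second reports two errors, one for the unsupported device and one for the out-of-range smoothing parameter.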