module great
class GReaT
GReaT Class
The GReaT class handles the whole generation flow. It is used to fine-tune a large language model on tabular data and to sample synthetic tabular data from the fine-tuned model.
Attributes:
- llm (str): HuggingFace checkpoint of a pretrained large language model, used as basis of our model
- tokenizer (AutoTokenizer): Tokenizer, automatically downloaded from the llm checkpoint
- model (AutoModelForCausalLM): Large language model, automatically downloaded from the llm checkpoint
- experiment_dir (str): Directory where the training checkpoints will be saved
- epochs (int): Number of epochs to fine-tune the model
- batch_size (int): Batch size used for fine-tuning
- efficient_finetuning (str): Fine-tuning method. Set to "lora" for LoRA fine-tuning.
- float_precision (int | None): Number of decimal places for floating point values. None means full precision.
- train_hyperparameters (dict): Additional hyperparameters added to the TrainingArguments used by the HuggingFace library; see the HuggingFace documentation for the full list of possible values
- columns (list): List of all features/columns of the tabular dataset
- num_cols (list): List of all numerical features/columns of the tabular dataset
- conditional_col (str): Name of a feature/column on which the sampling can be conditioned
- conditional_col_dist (dict | list): Distribution of the feature/column specified by conditional_col
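A minimal end-to-end sketch of this flow, assuming the class is imported from the be_great package (adjust the import to your installation); the DataFrame and checkpoint name are illustrative only:

```python
import pandas as pd
from be_great import GReaT  # assumed import path; adjust to your installation

# Illustrative toy data; any tabular pandas DataFrame works.
data = pd.DataFrame({
    "age": [22, 35, 58, 41],
    "income": [18000.0, 54000.5, 72250.0, 61300.0],
    "sex": ["female", "male", "male", "female"],
})

model = GReaT(llm="distilgpt2", epochs=50, batch_size=32)  # any causal-LM checkpoint
model.fit(data)                           # fine-tune the LLM on the tabular data
synthetic = model.sample(n_samples=100)   # returns a pandas DataFrame of synthetic rows
```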
method GReaT.__init__
__init__(
llm: str,
experiment_dir: str = 'trainer_great',
epochs: int = 100,
batch_size: int = 8,
efficient_finetuning: str = '',
lora_config: Optional[Dict[str, Any]] = None,
float_precision: Optional[int] = None,
report_to: List[str] = [],
**train_kwargs
)
Initializes GReaT.
Args:
- llm: HuggingFace checkpoint of a pretrained large language model, used as basis for our model
- experiment_dir: Directory where the training checkpoints will be saved
- epochs: Number of epochs to fine-tune the model
- batch_size: Batch size used for fine-tuning
- efficient_finetuning: Fine-tuning method. Set to "lora" to enable LoRA (Low-Rank Adaptation) fine-tuning. Requires the peft package.
- lora_config: Optional dictionary of LoRA hyperparameters to override the defaults. Supported keys: r (rank, default 16), lora_alpha (scaling factor, default 32), target_modules (list of module names or None for auto-detection), lora_dropout (default 0.05), bias (default "none"), task_type (default "CAUSAL_LM"), modules_to_save (default None).
- float_precision: Number of decimal places to use for floating point numbers. If None, full precision is used.
- report_to: List of integrations to report to (e.g. ["wandb"]). An empty list disables reporting.
- train_kwargs: Additional hyperparameters added to the TrainingArguments used by the HuggingFace library; see the HuggingFace documentation for the full list of possible values
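A constructor sketch based on the parameters above; the checkpoint name and the lora_config values are illustrative, and learning_rate stands in for any extra TrainingArguments value passed via **train_kwargs:

```python
from be_great import GReaT  # assumed import path

model = GReaT(
    llm="gpt2",                      # HuggingFace checkpoint used as the base model
    experiment_dir="trainer_great",
    epochs=100,
    batch_size=8,
    efficient_finetuning="lora",     # requires the peft package
    lora_config={"r": 8, "lora_alpha": 16, "lora_dropout": 0.1},
    float_precision=3,               # round floats to 3 decimal places
    report_to=["wandb"],             # empty list (the default) disables reporting
    learning_rate=5e-5,              # example extra TrainingArguments value via **train_kwargs
)
```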
method GReaT.fit
fit(
data: Union[DataFrame, ndarray],
column_names: Optional[List[str]] = None,
conditional_col: Optional[str] = None,
resume_from_checkpoint: Union[bool, str] = False,
random_conditional_col: bool = True
) → GReaTTrainer
Fine-tune GReaT using tabular data.
Args:
- data: Pandas DataFrame or Numpy Array that contains the tabular data
- column_names: If data is a Numpy Array, the feature names have to be defined. If data is a Pandas DataFrame, this value is ignored
- conditional_col: If given, the distribution of this column is saved and used as a starting point for the generation process later. If None, the last column is considered as conditional feature
- resume_from_checkpoint: If True, resumes training from the latest checkpoint in the experiment_dir. If a path, resumes training from the given checkpoint (has to be a valid HuggingFace checkpoint!)
- random_conditional_col: If True, a different random column is selected for conditioning at the end of each training epoch. This prevents overfitting on a single column and leads to more balanced synthetic data.
Returns: GReaTTrainer used for the fine-tuning process
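Fitting sketches under the assumption that model is a GReaT instance and df is a pandas DataFrame of training data with an "income" column; for NumPy input, column_names must be passed explicitly:

```python
# DataFrame input: column names come from the frame itself.
trainer = model.fit(df, conditional_col="income")

# NumPy input: feature names have to be supplied explicitly.
trainer = model.fit(df.to_numpy(), column_names=list(df.columns))

# Resume an interrupted run from the latest checkpoint in experiment_dir.
trainer = model.fit(df, resume_from_checkpoint=True)
```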
method GReaT.sample
sample(
n_samples: int,
start_col: Optional[str] = '',
start_col_dist: Optional[Union[dict, list]] = None,
temperature: float = 0.7,
k: int = 100,
max_length: int = 100,
drop_nan: bool = False,
device: str = 'cuda',
guided_sampling: bool = False,
random_feature_order: bool = True
) → DataFrame
Generate synthetic tabular data samples.
Args:
- n_samples: Number of synthetic samples to generate
- start_col: Feature to use as starting point for the generation process. If not given, the target learned during the fitting is used as starting point
- start_col_dist: Feature distribution of the starting feature. Should have the format {"F1": p1, "F2": p2, ...} for discrete columns or be a list of possible values for continuous columns. If not given, the target distribution learned during the fitting is used as starting point
- temperature: The generation samples each token from the probability distribution given by a softmax function. The temperature parameter controls this softmax function: a low temperature makes it sharper (0 equals greedy search), while a high temperature brings more diversity but also uncertainty into the output.
- k: Sampling batch size. Set as high as possible; it speeds up the generation process significantly
- max_length: Maximal number of tokens to generate - has to be long enough to not cut off any information!
- drop_nan: If True, rows with any NaN values are dropped from the generated output
- device: Set to "cpu" if the GPU should not be used. You can also specify a concrete GPU (e.g. "cuda:0")
- guided_sampling: If True, enables feature-by-feature guided generation. This is slower but can produce more reliable results for datasets with many features or complex relationships.
- random_feature_order: If True (and guided_sampling=True), the order of feature generation is randomized for each sample. Helps avoid ordering bias.
Returns: Pandas DataFrame with n_samples rows of generated data
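Sampling sketches assuming a fitted model; the column name and the start distribution are illustrative:

```python
# Unconditional sampling using the conditional column learned during fit().
synthetic = model.sample(n_samples=1000, k=50, max_length=200, device="cuda:0")

# Start generation from an explicit discrete distribution and drop incomplete rows.
synthetic = model.sample(
    n_samples=500,
    start_col="sex",
    start_col_dist={"female": 0.5, "male": 0.5},
    temperature=0.7,
    drop_nan=True,
)
```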
method GReaT.great_sample
great_sample(
starting_prompts: Union[str, list[str]],
temperature: float = 0.7,
max_length: int = 100,
device: str = 'cuda'
) → DataFrame
Generate synthetic tabular data samples conditioned on a given input.
Args:
- starting_prompts: String or list of strings on which the output is conditioned. For example, "Sex is female, Age is 26"
- temperature: The generation samples each token from the probability distribution given by a softmax function. The temperature parameter controls this softmax function: a low temperature makes it sharper (0 equals greedy search), while a high temperature brings more diversity but also uncertainty into the output.
- max_length: Maximal number of tokens to generate - has to be long enough to not cut off any information
- device: Set to "cpu" if the GPU should not be used. You can also specify a concrete GPU.
Returns: Pandas DataFrame with synthetic data generated based on starting_prompts
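A small sketch assuming a fitted model; the feature names in the prompts must match the columns of the training data:

```python
prompts = ["Sex is female, Age is 26", "Sex is male, Age is 47"]
conditioned = model.great_sample(prompts, temperature=0.7, max_length=100, device="cuda")
```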
method GReaT.impute
impute(
df_miss: DataFrame,
temperature: float = 0.7,
k: int = 100,
max_length: int = 100,
max_retries: int = 15,
device: str = 'cuda'
) → DataFrame
Impute a DataFrame with missing values using a trained GReaT model.
Args:
- df_miss: Pandas DataFrame of the exact same format (column names, value ranges/types) as the data used to train the GReaT model, with missing values indicated by NaN. This function will sample the missing values conditioned on the remaining values.
- temperature: Controls the softmax function during generation. Lower values produce more deterministic output.
- k: Sampling batch size
- max_length: Maximal number of tokens to generate
- max_retries: Maximum number of retries if imputation fails to fill all values
- device: Set to "cpu" if the GPU should not be used
Returns: Pandas DataFrame with imputed values
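An imputation sketch assuming a fitted model and a DataFrame df with the training schema; the cells to impute are set to NaN first:

```python
import numpy as np

df_miss = df.copy()
df_miss.loc[0, "income"] = np.nan   # mark cells to impute with NaN
df_miss.loc[2, "sex"] = np.nan

df_imputed = model.impute(df_miss, temperature=0.7, k=100, max_retries=15)
```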
method GReaT.save
save(path: str)
Save GReaT Model
Saves the model weights and a configuration file in the given directory. If LoRA fine-tuning was used, saves the adapter weights separately using PEFT's native save_pretrained method so they can be reloaded efficiently. Supports remote file systems via fsspec (e.g. s3://, gs://).
Args:
path: Path where to save the model
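A saving sketch; the directory names are illustrative:

```python
model.save("my_great_model")            # local directory
# model.save("s3://my-bucket/great")    # remote paths via fsspec, as noted above
```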
classmethod GReaT.load_from_dir
load_from_dir(path: str)
Load GReaT class
Load trained GReaT model from directory. Automatically detects whether the model was saved with LoRA adapters or as a full checkpoint. Supports remote file systems via fsspec.
Args:
path: Directory where GReaT model is saved
Returns: New instance of GReaT loaded from directory
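A loading sketch, assuming the model was previously saved with save() to the directory shown above:

```python
from be_great import GReaT  # assumed import path

model = GReaT.load_from_dir("my_great_model")
synthetic = model.sample(n_samples=100)
```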