seisnn.model.GAN_trainer.GeneratorTrainer
- class seisnn.model.GAN_trainer.GeneratorTrainer(database=None, model=<tensorflow.python.keras.engine.functional.Functional object>, optimizer=<tensorflow.python.keras.optimizer_v2.adam.Adam object>, loss=<tensorflow.python.keras.losses.BinaryCrossentropy object>)
Bases: seisnn.model.GAN_trainer.BaseTrainer
Trainer class for the GAN generator.
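The rendered signature shows object reprs such as `<tensorflow.python.keras.engine.functional.Functional object>` as defaults. This suggests the default model, optimizer, and loss are instances created once, when the `def` statement runs at import time. Python evaluates default argument values a single time, so every trainer built without explicit arguments would share the same objects. The sketch below demonstrates this with a hypothetical `Trainer` class and a stand-in `Model` class (neither is part of seisnn). Passing fresh objects, or using a `None` sentinel as in the hypothetical `SafeTrainer`, avoids the sharing.

```python
class Model:
    """Stand-in for a Keras model; hypothetical, for illustration only."""

    def __init__(self):
        self.weights = []


class Trainer:
    # The default Model() is evaluated once, when this def statement runs.
    def __init__(self, model=Model()):
        self.model = model


class SafeTrainer:
    # A None sentinel builds a new Model for every instance instead.
    def __init__(self, model=None):
        self.model = model if model is not None else Model()


a, b = Trainer(), Trainer()
print(a.model is b.model)  # True: both share one default instance

c, d = SafeTrainer(), SafeTrainer()
print(c.model is d.model)  # False: each gets its own instance
```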
- __init__(database=None, model=<tensorflow.python.keras.engine.functional.Functional object>, optimizer=<tensorflow.python.keras.optimizer_v2.adam.Adam object>, loss=<tensorflow.python.keras.losses.BinaryCrossentropy object>)
Initialize the trainer.
- Parameters
database – SQL database.
model – Keras model.
optimizer – Keras optimizer.
loss – Keras loss function.
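The default loss is a Keras `BinaryCrossentropy` object. For probability inputs (`from_logits=False`, the Keras default), it computes roughly the mean of -[y*log(p) + (1 - y)*log(1 - p)] over the elements, with predictions clipped away from 0 and 1 for numerical stability. The pure-Python sketch below approximates that computation. The function name `binary_crossentropy` is hypothetical, the clipping constant 1e-7 is assumed to match the Keras backend epsilon, and Keras features such as sample weights and label smoothing are omitted.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over paired labels and probabilities.

    Hypothetical helper approximating Keras BinaryCrossentropy with
    from_logits=False; sample weights and label smoothing are omitted.
    """
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    total = 0.0
    for y, p in zip(y_true, y_pred):
        # Clip probabilities so log() never receives exactly 0.
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(y_true)


# A confident, correct prediction gives a small loss; a wrong one a large loss.
print(binary_crossentropy([1.0, 0.0], [0.9, 0.1]))  # about 0.105
print(binary_crossentropy([1.0, 0.0], [0.1, 0.9]))  # about 2.303
```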
Methods
__init__([database, model, optimizer, loss]) – Initialize the trainer.
get_dataset_length([database, tfr_list])
get_model_dir(model_instance[, remove])
train_loop(tfr_list, model_name[, epochs, …]) – Main training loop.
train_step(train, val) – Training step.
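The methods table lists `get_model_dir(model_instance[, remove])` without a description, while the `remove` parameter of `train_loop` is documented as removing the model folder before training. A minimal sketch of that behavior might look like the following, assuming the directory is simply a base path joined with a model name. The `prepare_model_dir` function, its `base_dir` argument, and the directory layout are hypothetical and not taken from seisnn.

```python
import os
import shutil
import tempfile


def prepare_model_dir(base_dir, model_name, remove=False):
    """Return base_dir/model_name, creating it if needed.

    Hypothetical helper: if remove is True, any existing folder is deleted
    first, so training starts from an empty directory.
    """
    model_dir = os.path.join(base_dir, model_name)
    if remove and os.path.isdir(model_dir):
        shutil.rmtree(model_dir)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


with tempfile.TemporaryDirectory() as tmp:
    path = prepare_model_dir(tmp, "generator")
    open(os.path.join(path, "old_checkpoint"), "w").close()
    # Without remove, earlier files are kept.
    print(os.listdir(prepare_model_dir(tmp, "generator")))  # ['old_checkpoint']
    # With remove=True, the folder is recreated empty.
    print(os.listdir(prepare_model_dir(tmp, "generator", remove=True)))  # []
```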
- train_loop(tfr_list, model_name, epochs=1, batch_size=1, log_step=100, plot=False, remove=False)
Main training loop.
- Parameters
tfr_list – List of TFRecord paths.
model_name (str) – Model directory name.
epochs (int) – Number of epochs.
batch_size (int) – Batch size.
log_step (int) – Number of training steps between log outputs.
plot (bool) – Plot training and validation results: if True, show the figure; if False, save it.
remove (bool) – If True, remove model folder before training.
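To illustrate how `epochs`, `batch_size`, and `log_step` typically interact, here is a framework-free sketch of a generic mini-batch loop. Each epoch it splits the data into batches, counts global steps, and logs every `log_step` steps. The `sketch_train_loop` name, the list-based dataset, and the `train_step` callable are stand-ins assumed for illustration. The actual seisnn loop reads from the TFRecord files in `tfr_list` and may count or log steps differently.

```python
def sketch_train_loop(data, train_step, epochs=1, batch_size=1, log_step=100):
    """Generic mini-batch loop; returns the (step, loss) pairs that were logged.

    Hypothetical sketch: `data` is a list of examples and `train_step`
    is any callable that takes a batch and returns a loss value.
    """
    if batch_size < 1 or log_step < 1:
        raise ValueError("batch_size and log_step must be positive")
    logged = []
    step = 0
    for epoch in range(epochs):
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            loss = train_step(batch)
            step += 1
            # Log only every `log_step` steps to keep output short.
            if step % log_step == 0:
                logged.append((step, loss))
                print(f"epoch {epoch + 1}, step {step}, loss {loss:.4f}")
    return logged


# 10 examples with batch_size=4 give 3 batches per epoch (4, 4 and 2 examples).
history = sketch_train_loop(list(range(10)), train_step=lambda b: sum(b) / len(b),
                            epochs=2, batch_size=4, log_step=2)
```

With these settings the loop runs 6 steps in total and logs steps 2, 4 and 6, so `history` is `[(2, 5.5), (4, 1.5), (6, 8.5)]`.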
- Returns