seisnn.model.GAN_trainer.GeneratorTrainer

class seisnn.model.GAN_trainer.GeneratorTrainer(database=None, model=<tensorflow.python.keras.engine.functional.Functional object>, optimizer=<tensorflow.python.keras.optimizer_v2.adam.Adam object>, loss=<tensorflow.python.keras.losses.BinaryCrossentropy object>)

Bases: seisnn.model.GAN_trainer.BaseTrainer

Trainer class for the generator model of the GAN.

__init__(database=None, model=<tensorflow.python.keras.engine.functional.Functional object>, optimizer=<tensorflow.python.keras.optimizer_v2.adam.Adam object>, loss=<tensorflow.python.keras.losses.BinaryCrossentropy object>)

Initialize the trainer.

Parameters
  • database – SQL database.

  • model – Keras model.

  • optimizer – Keras optimizer.

  • loss – Keras loss.
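The default loss is a Keras BinaryCrossentropy instance. For reference, the following minimal pure-Python sketch shows the quantity it computes on probability outputs (Keras' default from_logits=False). The function name and the eps clipping value of 1e-7 are assumptions chosen to approximate Keras' behavior; this is not seisnn code.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over paired labels and probabilities.

    Hypothetical helper approximating Keras' BinaryCrossentropy with
    from_logits=False: probabilities are clipped to [eps, 1 - eps] so
    that log(0) never occurs, and per-element losses are averaged.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # keep log() finite
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(y_true)


print(round(binary_crossentropy([1, 0], [0.9, 0.1]), 4))  # 0.1054
```

Confident, correct predictions give a small loss; the clipping keeps the loss finite even for a prediction of exactly 0 or 1.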

Methods

__init__([database, model, optimizer, loss])

Initialize the trainer.

get_dataset_length([database, tfr_list])

get_model_dir(model_instance[, remove])

train_loop(tfr_list, model_name[, epochs, …])

Main training loop.

train_step(train, val)

Training step.

train_loop(tfr_list, model_name, epochs=1, batch_size=1, log_step=100, plot=False, remove=False)

Main training loop.

Parameters
  • tfr_list – List of TFRecord paths.

  • model_name (str) – Model directory name.

  • epochs (int) – Number of epochs.

  • batch_size (int) – Batch size.

  • log_step (int) – Logging interval, in training steps.

  • plot (bool) – How to handle the training/validation plot: if True, show the figure; if False, save it to file.

  • remove (bool) – If True, remove model folder before training.

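The roles of epochs, batch_size, log_step and remove can be illustrated with a generic loop outline. This is a pure-Python sketch under stated assumptions, not seisnn's implementation: run_training, batches, step_fn, val_batch and model_dir are hypothetical names, and it assumes logging is keyed to a global step counter and that remove deletes the model folder before training begins.

```python
import os
import shutil
import tempfile


def batches(records, batch_size):
    """Yield consecutive batches of at most batch_size records."""
    for start in range(0, len(records), batch_size):
        yield records[start:start + batch_size]


def run_training(records, val_batch, step_fn, model_dir,
                 epochs=1, batch_size=1, log_step=100, remove=False):
    """Hypothetical epoch/batch loop with periodic logging.

    step_fn(train_batch, val_batch) is assumed to return
    (train_loss, val_loss), like the documented train_step.
    """
    if remove and os.path.isdir(model_dir):
        shutil.rmtree(model_dir)  # start from an empty model folder
    os.makedirs(model_dir, exist_ok=True)

    history = []
    step = 0
    for epoch in range(epochs):
        for train_batch in batches(records, batch_size):
            loss, val_loss = step_fn(train_batch, val_batch)
            if step % log_step == 0:  # log once every log_step steps
                history.append((epoch, step, loss, val_loss))
                print(f"epoch {epoch} step {step}: "
                      f"loss={loss:.4f} val_loss={val_loss:.4f}")
            step += 1
    return history


# Usage with a dummy step function: 10 records, batch_size=4 gives
# 3 steps per epoch, so 2 epochs run steps 0..5 and steps 0, 2, 4 are logged.
with tempfile.TemporaryDirectory() as tmp:
    run_training(list(range(10)), [0],
                 lambda tr, va: (float(len(tr)), 0.0),
                 os.path.join(tmp, "my_model"),
                 epochs=2, batch_size=4, log_step=2)
```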

train_step(train, val)

Training step.

Parameters
  • train – Training data.

  • val – Validation data.

Return type

float

Returns

The prediction (training) loss and the validation loss.
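A step that returns both a training loss and a validation loss commonly follows this pattern: one optimizer update on the training batch, then a loss evaluation on the validation batch with no update. The following pure-Python sketch shows that pattern on a toy logistic model. It is not seisnn's TensorFlow code: the model p = sigmoid(w*x + b), the params dict, the lr argument and the helper names are all hypothetical, and plain gradient descent stands in for the Adam optimizer.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce(y, p, eps=1e-7):
    """Binary cross-entropy for one label/probability pair."""
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def train_step(params, train, val, lr=0.1):
    """Hypothetical training step for p = sigmoid(w * x + b).

    train and val are lists of (x, y) pairs. Updates params in place
    with one gradient-descent step on train, then evaluates val without
    updating. Returns (train_loss, val_loss).
    """
    w, b = params["w"], params["b"]
    # Forward pass and loss on the training batch (before the update).
    preds = [sigmoid(w * x + b) for x, _ in train]
    n = len(train)
    train_loss = sum(bce(y, p) for (_, y), p in zip(train, preds)) / n
    # d(BCE)/d(logit) = p - y, averaged over the batch.
    grad_w = sum((p - y) * x for (x, y), p in zip(train, preds)) / n
    grad_b = sum((p - y) for (_, y), p in zip(train, preds)) / n
    params["w"] = w - lr * grad_w
    params["b"] = b - lr * grad_b
    # Validation loss with the updated parameters; no gradient step here.
    val_loss = sum(bce(y, sigmoid(params["w"] * x + params["b"]))
                   for x, y in val) / len(val)
    return train_loss, val_loss


params = {"w": 0.0, "b": 0.0}
data = [(1.0, 1), (-1.0, 0)]
print(train_step(params, data, data))
```

In this toy run the first training loss equals ln 2 (all predictions start at 0.5), and the validation loss, measured after the update, is slightly lower.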