Source code for seisnn.model.GAN_trainer

import os
import shutil

import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import concatenate

from seisnn.core import Instance
from seisnn.model.generator import nest_net
from seisnn.model.attention import transformer
from seisnn.model.GAN_model import build_discriminator, build_cgan, build_patch_discriminator
import seisnn.example_proto
import seisnn.io
import seisnn.sql
import seisnn.utils

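# Enable GPU memory growth so TensorFlow allocates memory on demand
# instead of reserving the whole GPU up front.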
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)


class BaseTrainer:
    """
    Base trainer: model directory and dataset-length helpers.
    """

    @staticmethod
    def get_dataset_length(database=None, tfr_list=None):
        """
        Return the total example count of the given TFRecords,
        or None if the database lookup fails.
        """
        count = None
        try:
            db = seisnn.sql.Client(database)
            tfr_list = seisnn.utils.flatten_list(tfr_list)
            counts = db.get_tfrecord(path=tfr_list, column='count')
            count = sum(seisnn.utils.flatten_list(counts))
        except Exception as error:
            print(f'{type(error).__name__}: {error}')
        return count

    @staticmethod
    def get_model_dir(model_instance, remove=False):
        """
        Create (and optionally clear) the model directory and its
        history subdirectory, returning both paths.
        """
        config = seisnn.utils.Config()
        save_model_path = os.path.join(config.models, model_instance)
        if remove:
            shutil.rmtree(save_model_path, ignore_errors=True)
        seisnn.utils.make_dirs(save_model_path)

        save_history_path = os.path.join(save_model_path, "history")
        seisnn.utils.make_dirs(save_history_path)

        return save_model_path, save_history_path
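

# Minimal usage sketch of the helpers above (illustrative, not part of
# the original module). It assumes a configured seisnn workspace with a
# `models` directory and a database containing a tfrecord table; the
# model name and TFRecord path are placeholders.
def _base_trainer_demo():
    model_dir, history_dir = BaseTrainer.get_model_dir(
        'demo_model', remove=False)
    print(model_dir, history_dir)

    # Returns None (and prints the error) if the lookup fails.
    length = BaseTrainer.get_dataset_length(
        database='CWB.db', tfr_list=['some.tfrecord'])
    print(length)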


class GeneratorTrainer(BaseTrainer):
    """
    Conditional GAN trainer for the phase-picking generator.
    """

    def __init__(self, database=None, model=None, optimizer=None, loss=None):
        """
        Initialize the trainer.

        :param database: SQL database name.
        :param model: keras model (currently unused; the generator,
            discriminator and cGAN are built internally).
        :param optimizer: keras optimizer stored on the trainer,
            defaults to Adam(1e-4).
        :param loss: keras loss (currently unused; the cGAN losses
            are set internally).
        """
        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(1e-4)

        # Generator and discriminator over 1 x 3008 x 3 inputs with
        # three output classes.
        self.generator_model = transformer(
            img_rows=1, img_cols=3008, color_type=3, num_class=3)
        self.discriminator_model = build_discriminator(
            img_rows=1, img_cols=3008, color_type=3, num_class=3)

        self.generator_optimizer = Adam(learning_rate=1e-3)
        self.generator_model.compile(
            loss='binary_crossentropy',
            optimizer=self.generator_optimizer)

        # Freeze the discriminator while compiling the combined cGAN so
        # that cgan_model.train_on_batch updates only the generator
        # (see _frozen_discriminator_demo below).
        self.discriminator_model.trainable = False
        self.cgan_model = build_cgan(
            self.generator_model, self.discriminator_model)
        loss = [tf.keras.losses.BinaryCrossentropy(),
                tf.keras.losses.BinaryCrossentropy()]
        # Weight the label reconstruction loss far above the
        # adversarial loss.
        loss_weights = [100000, 1]
        self.cgan_optimizer = Adam(learning_rate=1e-3)
        self.cgan_model.compile(
            loss=loss, loss_weights=loss_weights,
            optimizer=self.cgan_optimizer)

        # Unfreeze and compile the stand-alone discriminator.
        self.discriminator_optimizer = Adam(learning_rate=1e-3)
        self.discriminator_model.trainable = True
        self.discriminator_model.compile(
            loss='binary_crossentropy',
            optimizer=self.discriminator_optimizer)

        self.database = database
        self.model = self.cgan_model
        self.optimizer = optimizer

    def train_loop(self, tfr_list, model_name,
                   epochs=1, batch_size=1,
                   log_step=100, plot=False, remove=False):
        """
        Main training loop.

        :param tfr_list: List of TFRecord paths.
        :param str model_name: Model directory name.
        :param int epochs: Number of epochs.
        :param int batch_size: Batch size.
        :param int log_step: Logging step interval.
        :param bool plot: If True, show validation figures;
            if False, save them to the history directory.
        :param bool remove: If True, remove the model folder before
            training.
        """
        model_path, history_path = self.get_model_dir(
            model_name, remove=remove)

        ckpt = tf.train.Checkpoint(
            generator_model=self.generator_model,
            discriminator_model=self.discriminator_model,
            cgan_model=self.cgan_model,
            generator_optimizer=self.generator_optimizer,
            discriminator_optimizer=self.discriminator_optimizer,
            cgan_optimizer=self.cgan_optimizer,
        )
        ckpt_manager = tf.train.CheckpointManager(
            ckpt, model_path, max_to_keep=100)

        if ckpt_manager.latest_checkpoint:
            ckpt.restore(ckpt_manager.latest_checkpoint)
            last_epoch = len(ckpt_manager.checkpoints)
            print(f'Latest checkpoint epoch {last_epoch} restored!!')

        dataset = seisnn.io.read_dataset(tfr_list)
        dataset = dataset.shuffle(5000)
        # Hold out one fixed batch for validation snapshots.
        val = next(iter(dataset.batch(1)))

        metrics_names = ['d_loss', 'g_loss']
        data_len = self.get_dataset_length(self.database, tfr_list)
        for epoch in range(epochs):
            print(f'epoch {epoch + 1} / {epochs}')
            n = 0
            progbar = tf.keras.utils.Progbar(
                data_len, stateful_metrics=metrics_names)

            for train in dataset.prefetch(100).batch(batch_size):
                d_loss, g_loss = self.train_step(train, val)
                values = [('d_loss', d_loss), ('g_loss', g_loss)]
                progbar.add(len(train['id']), values=values)

                n += 1
                if n % log_step == 0:
                    # Snapshot the generator on the held-out batch and
                    # print the discriminator score of the fake pair.
                    val['predict'] = self.generator_model(val['trace'])
                    concat = concatenate(
                        [val['trace'], val['predict']], axis=3)
                    score = self.discriminator_model(concat)
                    print(score)

                    title = f'epoch{epoch + 1:0>2}_step{n:0>5}___'
                    val['id'] = tf.convert_to_tensor(
                        title.encode('utf-8'), dtype=tf.string)[tf.newaxis]

                    example = next(seisnn.example_proto.batch_iterator(val))
                    instance = Instance(example)
                    if plot:
                        instance.plot()
                    else:
                        instance.plot(save_dir=history_path)

            ckpt_save_path = ckpt_manager.save()
            print(f'Saving checkpoint to {ckpt_save_path}')

        # Export the trained generator beside its checkpoints; it can be
        # reloaded as in _load_trained_generator below.
        self.generator_model.save(
            os.path.join(model_path, f'{model_name}.h5'))

    def train_step(self, train, val):
        """
        Single adversarial training step.

        :param train: Training batch.
        :param val: Validation batch (kept for API symmetry, unused).
        :return: discriminator loss, generator loss
        """
        real = np.ones((train['trace'].shape[0], 1))
        fake = np.zeros((train['trace'].shape[0], 1))

        # Discriminator update on a generated (trace, prediction) pair
        # labeled fake ...
        g_pred = self.generator_model(train['trace'], training=False)
        concat = concatenate([train['trace'], g_pred], axis=3)
        f_disc_loss = self.discriminator_model.train_on_batch(concat, fake)

        # ... and on a (trace, label) pair labeled real.
        concat = concatenate([train['trace'], train['label']], axis=3)
        r_disc_loss = self.discriminator_model.train_on_batch(concat, real)
        disc_loss = 0.5 * (f_disc_loss + r_disc_loss)

        # Generator update through the combined model while the
        # discriminator is frozen (see _adversarial_step_demo below).
        self.discriminator_model.trainable = False
        gen_loss = self.cgan_model.train_on_batch(
            train['trace'], [train['label'], real])
        self.discriminator_model.trainable = True

        return np.array(disc_loss), np.array(gen_loss)
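

# Sketch of the compile-order trick used in __init__, reduced to plain
# Keras with toy shapes (illustrative, not the seisnn models). In
# tf.keras a change to `trainable` only takes effect the next time the
# affected model is compiled, so the discriminator can be compiled
# trainable on its own, then frozen inside the combined model.
def _frozen_discriminator_demo():
    gen = tf.keras.Sequential(
        [tf.keras.layers.Dense(4, activation='sigmoid', input_shape=(4,))])
    disc = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(4,))])

    # Compile the discriminator while it is trainable.
    disc.compile(loss='binary_crossentropy', optimizer='adam')

    # Freeze it, then compile the combined model: training the combined
    # model now only updates the generator's weights.
    disc.trainable = False
    combined = tf.keras.Sequential([gen, disc])
    combined.compile(loss='binary_crossentropy', optimizer='adam')
    disc.trainable = True
    return disc, combined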
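

# Toy version of the update order in train_step (illustrative only).
# The real trainer concatenates trace and prediction/label along the
# channel axis before the discriminator; here the discriminator sees
# the samples directly.
def _adversarial_step_demo():
    disc, combined = _frozen_discriminator_demo()
    gen = combined.layers[0]

    x = np.random.rand(8, 4).astype('float32')
    y = np.random.rand(8, 4).astype('float32')
    real = np.ones((8, 1))
    fake = np.zeros((8, 1))

    # Discriminator: penalize generated samples, reward real ones,
    # then average the two losses as train_step does.
    f_loss = disc.train_on_batch(gen(x, training=False), fake)
    r_loss = disc.train_on_batch(y, real)
    d_loss = 0.5 * (f_loss + r_loss)

    # Generator: train through the combined model toward 'real'.
    g_loss = combined.train_on_batch(x, real)
    return d_loss, g_loss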
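

# Usage sketch (an assumption, not from the original module): reloading
# the generator exported at the end of train_loop. If `transformer`
# contains custom layers they must be supplied via `custom_objects`;
# `compile=False` skips restoring the training configuration.
def _load_trained_generator(model_path, model_name):
    generator = tf.keras.models.load_model(
        os.path.join(model_path, f'{model_name}.h5'), compile=False)
    return generator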


if __name__ == "__main__":
    database = 'CWB.db'
    db = seisnn.sql.Client(database)
    tfr_list = db.get_tfrecord(from_date='2000-01-01',
                               to_date='2018-12-31',
                               column='path')
    model_instance = 'qqq'

    trainer = GeneratorTrainer(database)
    trainer.train_loop(tfr_list, model_instance,
                       batch_size=250, epochs=100, plot=True)