Source code for seisnn.components

import itertools
import os

import numpy as np
import obspy
from obspy import UTCDateTime

import seisnn.core
import seisnn.example_proto
import seisnn.io
import seisnn.utils


[docs]class TFRecordConverter: """ Main class for TFRecord Converter. Consumes data from external source and emit TFRecord. """
[docs] def __init__(self, phase=('P', 'S', 'N'), trace_length=30, shape='triang'): self.phase = phase self.trace_length = trace_length self.shape = shape
[docs] def convert_training_from_picks(self, pick_list, tag, database,cpu_count = None): """ Convert training TFRecords from database picks. :param pick_list: List of picks from Pick SQL query. :param str tag: Pick tag in SQL database. :param str database: SQL database name. """ pick_list = sorted(pick_list, key=lambda pick: [pick.station, pick.time]) pick_groupby = itertools.groupby( pick_list, key=lambda pick: [pick.station, UTCDateTime(pick.time).julday]) group_picks = [[item for item in data] for (key, data) in pick_groupby] seisnn.utils.parallel(group_picks, func=self.write_tfrecord, sub_dir='train', tag=tag, database=database, batch_size=1, cpu_count = cpu_count)
def write_tfrecord(self, picks, sub_dir, tag, database): instance_list = self.get_instance_list(picks, tag, database) if instance_list: feature_list = [instance.to_feature() for instance in instance_list] example_list = [seisnn.example_proto.feature_to_example(feature) for feature in feature_list] tfr_dir = instance_list[0].get_tfrecord_dir(sub_dir) seisnn.utils.make_dirs(tfr_dir) file_name = instance_list[0].get_tfrecord_name() save_file = os.path.join(tfr_dir, file_name) seisnn.io.write_tfrecord(example_list, save_file) print(f'output {file_name}')
[docs] def get_instance_list(self, picks, tag, database): """ Returns instance list form list of picks and SQL database. :param picks: List of picks. :param str tag: Pick tag in SQL database. :param str database: SQL database root. :return: """ metadata = self.get_time_window(anchor_time=UTCDateTime(0), station='', ) instance_list = [] try: for pick in picks: if metadata.starttime < pick.time < metadata.endtime: continue elif metadata.endtime < pick.time < metadata.endtime + 30: metadata = self.get_time_window( anchor_time=metadata.starttime + 30, station=pick.station, ) else: metadata = self.get_time_window(anchor_time=pick.time, station=pick.station, shift='random') streams = seisnn.io.read_sds(metadata,sds_path='/home/andy/mseed') for _, stream in streams.items(): stream = self.signal_preprocessing(stream) instance = seisnn.core.Instance(stream) instance.label = seisnn.core.Label(instance.metadata, self.phase) instance.label.generate_label(database, tag, self.shape) instance.predict = seisnn.core.Label(instance.metadata, self.phase) instance_list.append(instance) except Exception as e: print(f'station = {pick.station}, time = {pick.time}, error = {e}') return instance_list
[docs] def get_time_window(self, anchor_time, station, shift=0): """ Returns metadata from anchor time. :param anchor_time: Anchor of the time window. :param str station: Station name. :param float or str shift: (Optional.) Shift in sec, if 'random' will shift randomly within the trace length. :rtype: dict :return: Metadata object. """ if shift == 'random': rng = np.random.default_rng() shift = rng.random() * self.trace_length metadata = seisnn.core.Metadata() metadata.starttime = obspy.UTCDateTime(anchor_time) - shift metadata.endtime = metadata.starttime + self.trace_length metadata.station = station return metadata
[docs] def signal_preprocessing(self, stream): """ Return a signal processed stream. :param obspy.Stream stream: Stream object. :rtype: obspy.Stream :return: Processed stream. """ stream.detrend('demean') stream.detrend('linear') stream.normalize() stream.resample(100) stream = self.trim_trace(stream) return stream
[docs] @staticmethod def trim_trace(stream, points=3008): """ Return trimmed stream in a given length. :param obspy.Stream stream: Stream object. :param int points: Trace data length. :rtype: obspy.Stream :return: Trimmed stream. """ trace = stream[0] start_time = trace.stats.starttime if trace.data.size > 1: dt = (trace.stats.endtime - trace.stats.starttime) / ( trace.data.size - 1) end_time = start_time + dt * (points - 1) elif trace.data.size == 1: end_time = start_time else: print('No data points in trace') return stream.trim(start_time, end_time, nearest_sample=True, pad=True, fill_value=0) return stream