Source code for nahiku.greedy_search

import time
import torch
import gpytorch
import warnings
import numpy as np
import matplotlib.pyplot as plt

from .search import Search

from gpytorch.mlls import ExactMarginalLogLikelihood
from scipy.ndimage import minimum_filter1d


class GreedySearch(Search):
    """
    Greedy search for anomalous intervals in a time series using Gaussian
    Processes.

    Method:
        1. Perform GP regression on the time series.
        2. Find the most significant outlier interval of length
           ``len_deviant`` (the window whose smallest absolute residual is
           largest).
        3. Exclude the outlier interval and redo the regression. See if the
           grow metric improves.
        4. Expand the outlier interval in both directions by
           ``expansion_param`` and redo step 3.
        5. Repeat step 4 as long as the grow metric keeps improving.
        6. When there is no improvement, keep the last interval that did
           improve and define the anomaly signal as the difference between
           the data and the regression inside that interval.
        7. Repeat steps 2-6 while there are still points above the
           ``num_sigma_threshold``.
    """

    # Names accepted for ``which_grow_metric``; anything else is treated as "mll".
    _GROW_METRICS = ("nlpd", "msll", "rmse", "mll")

    def __init__(
        self,
        x,  # Map to Search.x
        y,  # Map to Search.y
        dominant_period,  # Map to Search.dominant_period
        device="cpu",  # Map to Search.device
        which_grow_metric="mll",
        y_err=None,
        num_sigma_threshold=3,
        expansion_param=1,
        len_deviant=1,
    ):
        """
        Initialize the GreedySearch class and the base Search class.

        Args:
            x (np.ndarray): x array of the light curve
            y (np.ndarray): y array of the light curve
            dominant_period (float): dominant period of the light curve
            device (str): device to use for GP modeling (default: "cpu")
            which_grow_metric (str): Metric used to decide whether expanding
                the anomalous region improved the fit. Options are 'nlpd',
                'msll', 'rmse', 'mll'. Any other value is treated as 'mll'
                (a warning is emitted). Default is 'mll'.
            y_err (np.ndarray or None): 1D array of observational errors. If
                None, zero error is assumed for all points. Only used for the
                residual threshold, not for GP training.
            num_sigma_threshold (float): Threshold, in standard deviations,
                for identifying anomalies. Default is 3.
            expansion_param (int): Number of indices by which the anomalous
                region is expanded on each side per greedy step. Default is 1.
            len_deviant (int): Length of the window used to find the most
                significant outlier. Values <= 0 are replaced by 1 with a
                warning. Default is 1.
        """
        # The base class is expected to set self.x, self.y, self.x_tensor,
        # self.y_tensor, self.device, self.num_detected_anomalies,
        # self.flagged_anomalous and self.runtime.
        # NOTE(review): this class writes to self.anomalous_signal; confirm
        # that Search.__init__ creates an attribute with exactly that name.
        super().__init__(x=x, y=y, dominant_period=dominant_period, device=device)

        # If y_err is not provided, use zero error for every point
        if y_err is None:
            print(
                "No y_err provided. Using y_err = 0 for all points. Note that y_err is only used for calculating residuals, not in GP training."
            )
            self.y_err = np.zeros_like(y, dtype=float)
        else:
            self.y_err = np.asarray(y_err)

        # Keep untouched copies of the inputs; self.x / self.y / self.y_err
        # shrink as anomalies are removed during the search.
        self.x_orig = np.array(x)
        self.y_orig = np.array(y)
        self.y_err_orig = np.copy(self.y_err)

        # Position of every point still in self.x within the original arrays.
        # Assumes the base class keeps self.x in the same order and length as
        # the x passed in here - TODO confirm against Search.__init__.
        self._orig_index = np.arange(len(self.x_orig))

        self.num_sigma_threshold = num_sigma_threshold
        self.expansion_param = expansion_param

        if which_grow_metric not in self._GROW_METRICS:
            warnings.warn(
                f"Unknown which_grow_metric '{which_grow_metric}'; "
                "the 'mll' metric will be used."
            )
        self.which_grow_metric = which_grow_metric

        # Check that len_deviant is valid
        if len_deviant <= 0:
            warnings.warn(
                "len_deviant must be greater than 0. Setting len_deviant = 1."
            )
            self.len_deviant = 1
        else:
            self.len_deviant = len_deviant

        # Placeholder until search_for_anomaly computes the real threshold
        self.threshold = np.zeros_like(self.x)

    def plot_greedy(self, x, pred_mean, left_edge, right_edge, residuals):
        """
        Plot the light curve, GP fit, and flagged anomalies for one step of
        the greedy search.

        Args:
            x (torch.Tensor or np.ndarray): x values that ``pred_mean`` was
                evaluated on. Tensors are moved to the CPU before plotting.
            pred_mean (np.ndarray): GP mean prediction at ``x``.
            left_edge (int): First index (into the current ``self.x``) of the
                interval being flagged.
            right_edge (int): One past the last index (into the current
                ``self.x``) of the interval being flagged.
            residuals (np.ndarray): Absolute residuals |observed - predicted|
                for every point of the current ``self.x``.
        """
        # matplotlib cannot consume device tensors directly
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()

        # right_edge is exclusive and may equal len(self.x); clamp it before
        # using it as an index for the shaded span.
        span_start = self.x[left_edge]
        span_end = self.x[min(right_edge, len(self.x) - 1)]
        already_flagged = self.flagged_anomalous == 1

        fig, axs = plt.subplots(2, 1, sharex=True, figsize=(8, 8))

        # Top panel: GP mean prediction vs. data
        axs[0].axvspan(
            span_start,
            span_end,
            color="green",
            alpha=0.4,
            label="New Flagged Anomaly",
        )
        axs[0].scatter(
            self.x_orig, self.y_orig, c="black", s=3, alpha=0.5, label="Observed"
        )
        # left_edge/right_edge index the current (reduced) arrays, so the new
        # points are taken from self.x / self.y rather than the originals.
        axs[0].scatter(
            self.x[left_edge:right_edge],
            self.y[left_edge:right_edge],
            c="red",
            s=10,
            marker="x",
            alpha=0.8,
        )
        axs[0].scatter(
            self.x_orig[already_flagged],
            self.y_orig[already_flagged],
            c="red",
            s=10,
            marker="x",
            alpha=0.8,
            label="Flagged as Anomalous",
        )
        axs[0].plot(x, pred_mean, lw=1, alpha=0.7, label="GP Mean Prediction")
        axs[0].set_ylim(np.min(self.y_orig), np.max(self.y_orig))
        axs[0].legend()
        axs[0].set_xlim(np.min(self.x_orig), np.max(self.x_orig))
        axs[0].set_xlabel("Time")
        axs[0].set_ylabel("Flux")

        # Bottom panel: residuals against the flagging threshold
        axs[1].axvspan(
            span_start,
            span_end,
            color="green",
            alpha=0.4,
            label="New Flagged Anomaly",
        )
        axs[1].plot(
            self.x,
            self.threshold,
            "--",
            lw=3,
            alpha=0.8,
            color="gold",
            label=f"Threshold = {self.num_sigma_threshold} "
            + r"$\sqrt{\text{var} + \text{err}^2}$",
        )
        axs[1].scatter(
            self.x,
            residuals,
            c="black",
            s=3,
            alpha=0.5,
            label="|Observed - GP Mean Prediction|",
        )
        axs[1].scatter(
            self.x[left_edge:right_edge],
            residuals[left_edge:right_edge],
            c="red",
            s=10,
            marker="x",
            alpha=0.8,
            label="New Flagged as Anomalous",
        )
        axs[1].legend()
        axs[1].set_xlim(np.min(self.x_orig), np.max(self.x_orig))
        axs[1].set_xlabel("Time")
        axs[1].set_ylabel("Flux")

        plt.tight_layout()
        plt.show()

    @staticmethod
    def _latent_mean(model, x_tensor):
        """Return the GP mean prediction at ``x_tensor`` as a numpy array."""
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            return model(x_tensor).mean.cpu().numpy()

    def _compute_threshold(self, residuals):
        """Return num_sigma_threshold * sqrt(y_err**2 + var(residuals))."""
        sum_variances = self.y_err**2 + np.var(residuals)
        return self.num_sigma_threshold * np.sqrt(sum_variances)

    def _detection_mask(self, detection_range):
        """Return a boolean mask of the current points inside detection_range."""
        if detection_range is None:
            return np.ones(len(self.x), dtype=bool)
        lower, upper = detection_range
        x_now = np.asarray(self.x)
        return (x_now >= lower) & (x_now <= upper)

    def _grow_metric(
        self, model, likelihood, x_sub, y_sub, observed_pred_sub, pred_mean_sub
    ):
        """Return the selected grow metric as a float where smaller is better."""
        if self.which_grow_metric == "nlpd":
            value = gpytorch.metrics.negative_log_predictive_density(
                observed_pred_sub, y_sub
            )
        elif self.which_grow_metric == "msll":
            value = gpytorch.metrics.mean_standardized_log_loss(
                observed_pred_sub, y_sub
            )
        elif self.which_grow_metric == "rmse":
            value = np.sqrt(np.mean((pred_mean_sub - y_sub.cpu().numpy()) ** 2))
        else:
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                output = model(x_sub)
                mll_func = ExactMarginalLogLikelihood(likelihood, model)
                # A log likelihood is better when larger; negate it so every
                # metric returned here is a loss.
                value = -mll_func(output, y_sub)
        return float(value)

    def search_for_anomaly(
        self,
        refit=False,
        neg_anomaly_only=False,
        pos_anomaly_only=False,
        plot=False,
        detection_range=None,
        update_threshold=False,
        **kwargs,
    ):
        """
        Run the greedy search for anomalies in the time series.

        Flagged points are removed from ``self.x``, ``self.y``, ``self.y_err``
        and the corresponding tensors; results are recorded (in the indexing
        of the original arrays) in ``self.flagged_anomalous``,
        ``self.anomalous_signal`` and ``self.num_detected_anomalies``. The
        elapsed wall time is stored in ``self.runtime``.

        Args:
            refit (bool): Whether to retrain the GP hyperparameters each time
                a model is rebuilt during the search (default: False).
            neg_anomaly_only (bool): Only flag intervals whose mean
                (data - prediction) is <= 0, i.e. dips (default: False).
                Takes precedence over ``pos_anomaly_only``.
            pos_anomaly_only (bool): Only flag intervals whose mean
                (data - prediction) is >= 0, i.e. flares (default: False).
            plot (bool): Whether to plot the light curve, GP fit, and flagged
                interval at each step of the search (default: False).
            detection_range (tuple or None): ``(lower, upper)`` range of x
                values in which outlier intervals may be seeded and in which
                the stopping criterion is evaluated. An interval may still
                grow past the range. If None, the entire x range is used.
                Default is None.
            update_threshold (bool): Whether to recompute the residual
                threshold after each removed interval (default: False).
            **kwargs: Additional keyword arguments passed to the GP training
                function, such as training_iterations, lr, early_stopping.
        """
        start_time = time.time()
        window = self.len_deviant
        step = max(self.expansion_param, 0)

        # Build and train a GP on the current x and y data for the initial fit
        model, likelihood, _, _ = self.build_gp_model(x=self.x_tensor, y=self.y_tensor)
        model, likelihood, _ = self.train_gp(
            gp_model=model,
            likelihood=likelihood,
            x=self.x_tensor,
            y=self.y_tensor,
            device=self.device,
            **kwargs,
        )

        # Learned kernel and mean seed every model rebuilt below
        kernel = model.covar_module
        mean = model.mean_module

        # Threshold for flagging anomalies from the initial residuals and y_err
        model.eval()
        likelihood.eval()
        residuals = np.abs(self.y - self._latent_mean(model, self.x_tensor))
        self.threshold = self._compute_threshold(residuals)

        # Map from positions in the current arrays to positions in the
        # original arrays; fall back to the identity if it is missing or stale.
        orig_index = getattr(self, "_orig_index", None)
        if orig_index is None or len(orig_index) != len(self.x):
            orig_index = np.arange(len(self.x))

        y_full = np.asarray(self.y_orig)
        x_full_tensor = torch.tensor(
            np.asarray(self.x_orig), dtype=torch.float32
        ).to(self.device)

        exist_points_above_threshold = True

        # Step 7 (repeat steps 2-6 while there are still points above threshold)
        while exist_points_above_threshold and len(self.x) > 0:
            n_points = len(self.x)
            in_range = self._detection_mask(detection_range)
            if not in_range.any():
                # Nothing left inside the requested detection range
                break

            # Rebuild on the current data, initialized from the previous
            # hyperparameters, and optionally refit.
            model, likelihood, _, _ = self.build_gp_model(
                x=self.x_tensor,
                y=self.y_tensor,
                kernel=kernel,
                mean=mean,
                likelihood=likelihood,
                device=self.device,
            )
            if refit:
                model, likelihood, _ = self.train_gp(
                    gp_model=model,
                    likelihood=likelihood,
                    x=self.x_tensor,
                    y=self.y_tensor,
                    device=self.device,
                    **kwargs,
                )

            # Mean prediction over the current x and over x_orig
            model.eval()
            likelihood.eval()
            with torch.no_grad(), gpytorch.settings.fast_pred_var():
                pred_mean = likelihood(model(self.x_tensor)).mean.cpu().numpy()
                pred_full_x = model(x_full_tensor).mean.cpu().numpy()

            # Smallest residual in each window of len_deviant points; the
            # window with the largest such value seeds the anomaly.
            residuals = np.abs(pred_mean - self.y)
            min_values = minimum_filter1d(residuals, size=window, mode="nearest")
            seed = int(np.argmax(np.where(in_range, min_values, -np.inf)))

            # minimum_filter1d centres its window on each index, so the seed
            # interval starts window // 2 points to the left of the argmax.
            left_edge = max(seed - window // 2, 0)
            right_edge = min(left_edge + window, n_points)
            final_model = model

            if plot:
                self.plot_greedy(self.x, pred_mean, left_edge, right_edge, residuals)

            # Steps 3-5: grow the interval [cand_left, cand_right) while the
            # metric keeps decreasing; keep only the last improving interval.
            positions = np.arange(n_points)
            cand_left, cand_right = left_edge, right_edge
            best_metric = float("inf")
            while True:
                keep = (positions < cand_left) | (positions >= cand_right)
                if not keep.any():
                    # Excluding the candidate would leave no data to fit
                    break

                x_sub = torch.tensor(self.x[keep], dtype=torch.float32).to(
                    self.device
                )
                y_sub = torch.tensor(self.y[keep], dtype=torch.float32).to(
                    self.device
                )

                cand_model, likelihood, _, _ = self.build_gp_model(
                    x=x_sub,
                    y=y_sub,
                    kernel=kernel,
                    mean=mean,
                    likelihood=likelihood,
                    device=self.device,
                )
                if refit:
                    cand_model, likelihood, _ = self.train_gp(
                        gp_model=cand_model,
                        likelihood=likelihood,
                        x=x_sub,
                        y=y_sub,
                        device=self.device,
                        **kwargs,
                    )

                # Predict on the subset and on the full x_orig
                cand_model.eval()
                likelihood.eval()
                with torch.no_grad(), gpytorch.settings.fast_pred_var():
                    observed_pred_sub = likelihood(cand_model(x_sub))
                    pred_mean_sub = observed_pred_sub.mean.cpu().numpy()
                    cand_pred_full = cand_model(x_full_tensor).mean.cpu().numpy()

                metric = self._grow_metric(
                    cand_model,
                    likelihood,
                    x_sub,
                    y_sub,
                    observed_pred_sub,
                    pred_mean_sub,
                )

                # Written as "not <" so a NaN metric also stops the growth
                if not metric < best_metric:
                    break

                # Accept this candidate
                best_metric = metric
                left_edge, right_edge = cand_left, cand_right
                final_model = cand_model
                pred_full_x = cand_pred_full

                if plot:
                    self.plot_greedy(
                        x_sub, pred_mean_sub, left_edge, right_edge, residuals
                    )

                # Expand both edges by expansion_param, clamped to the data
                next_left = max(cand_left - step, 0)
                next_right = min(cand_right + step, n_points)
                if (next_left, next_right) == (cand_left, cand_right):
                    break
                cand_left, cand_right = next_left, next_right

            # Step 6: record the anomaly in the indexing of the original data
            flagged_idx = orig_index[left_edge:right_edge]
            offsets = y_full[flagged_idx] - pred_full_x[flagged_idx]
            mean_offset = np.mean(offsets)

            if neg_anomaly_only:
                should_flag = mean_offset <= 0
                skip_message = f"Not flagging edges {left_edge}:{right_edge} because not a negative anomaly (mean(truth - pred) = {mean_offset}), and neg_anomaly_only is {neg_anomaly_only}. Still will remove edges from GP fit."
            elif pos_anomaly_only:
                should_flag = mean_offset >= 0
                skip_message = f"Not flagging edges {left_edge}:{right_edge} because not a positive anomaly (mean(truth - pred) = {mean_offset}), and pos_anomaly_only is {pos_anomaly_only}. Still will remove edges from GP fit."
            else:
                # Flag regardless of the sign of the anomaly
                should_flag = True
                skip_message = ""

            if should_flag:
                self.num_detected_anomalies += 1
                self.flagged_anomalous[flagged_idx] = 1
                self.anomalous_signal[flagged_idx] = minimum_filter1d(
                    np.abs(offsets),
                    size=window,
                    mode="nearest",
                )
                # Report positions in the original arrays, which is the
                # indexing used by flagged_anomalous
                print(f"Anomalous edges = {flagged_idx[0]}:{flagged_idx[-1] + 1}")
            else:
                print(skip_message)

            # Remove [left_edge, right_edge) from the working data for the
            # next iteration of the greedy search
            removed = np.arange(left_edge, right_edge)
            self.x = np.delete(self.x, removed)
            self.y = np.delete(self.y, removed)
            self.y_err = np.delete(self.y_err, removed)
            orig_index = np.delete(orig_index, removed)
            self._orig_index = orig_index
            self.x_tensor = torch.tensor(self.x, dtype=torch.float32).to(self.device)
            self.y_tensor = torch.tensor(self.y, dtype=torch.float32).to(self.device)

            if len(self.x) == 0:
                # Every point has been removed; nothing left to test
                self.threshold = np.delete(self.threshold, removed)
                break

            # Residuals of the remaining points under the accepted model
            final_model.eval()
            likelihood.eval()
            residuals = np.abs(
                self.y - self._latent_mean(final_model, self.x_tensor)
            )

            if update_threshold:
                self.threshold = self._compute_threshold(residuals)
            else:
                # Keep the old threshold, dropping the removed points
                self.threshold = np.delete(self.threshold, removed)

            # Check whether any window inside the detection range still has
            # all of its residuals above the threshold
            min_values = minimum_filter1d(residuals, size=window, mode="nearest")
            exist_points_above_threshold = bool(
                np.any(
                    (min_values > self.threshold)
                    & self._detection_mask(detection_range)
                )
            )

        self.runtime = time.time() - start_time