import gc
import time
import torch
import gpytorch
import warnings
import numpy as np
import matplotlib.pyplot as plt
from .search import Search
from .exhaustive_helpers import (
precompute_precision,
interval_posterior_from_precision,
compute_interval_pvalue,
)
from gpytorch.mlls import ExactMarginalLogLikelihood
[docs]
class ExhaustiveSearch(Search):
    """
    Exhaustive search for anomalous intervals in a time series using Gaussian Processes.

    Method:
    1. List every possible contiguous interval in the time series that could contain an
       anomaly, based on priors (e.g., minimum and maximum anomaly duration).
    2. Fit the entire time series with a GP and store the optimized parameters (and,
       optionally, the full precision matrix if using dynamic programming).
    3. For each candidate interval, compute the posterior likelihood of the test interval
       given a MultivariateNormal distribution fit to the rest of the data. Optionally, use
       dynamic programming to compute the posterior likelihoods more efficiently by
       leveraging the precision matrix of the full GP fit.
    4. Flag intervals as anomalous if the chosen test metric is below a threshold
       (or flag the N intervals with the lowest metric).
    """

    def __init__(
        self,
        x,  # Map to Search.x
        y,  # Map to Search.y
        dominant_period,  # Map to Search.dominant_period
        device="cpu",  # Map to Search.device
        min_anomaly_len=50,
        max_anomaly_len=100,
        window_slide_step=100,
        window_size_step=50,
        assume_independent=True,
        which_test_metric="pval",
    ):
        """
        Initialize the ExhaustiveSearch class and the base Search class.

        Args:
            x (np.ndarray): x array of the light curve
            y (np.ndarray): y array of the light curve
            dominant_period (float): dominant period of the light curve
            device (str): device to use for GP modeling (default: "cpu")
            min_anomaly_len (int): minimum length of candidate anomalous intervals.
                Values below 1 are replaced by 1 with a warning. (default: 50)
            max_anomaly_len (int): exclusive upper bound on the length of candidate
                intervals. If smaller than min_anomaly_len it is replaced by
                10 x min_anomaly_len with a warning. (default: 100)
            window_slide_step (int): step size for sliding the interval start across the
                time series. Values below 1 are replaced by 1 with a warning. (default: 100)
            window_size_step (int): step size for varying the length of the candidate
                intervals. Values below 1 are replaced by 1 with a warning. (default: 50)
            assume_independent (bool): if True, assumes independence between points for
                speed. False is not yet implemented and will be ignored for now. (default: True)
            which_test_metric (str): metric to use for evaluating the likelihood of test
                intervals. Options are 'pval', 'mahalanobis', 'nlpd', 'msll', 'rmse', 'mll';
                any other value falls back to the log-likelihood. (default: 'pval')
        """
        # Initialize the base Search class. This class relies on it for self.x, self.y,
        # self.x_tensor, self.y_tensor, self.device and for the result containers written
        # by search_for_anomaly (flagged_anomalous, anomalous_signal,
        # num_detected_anomalies, runtime).
        super().__init__(x=x, y=y, dominant_period=dominant_period, device=device)

        # Store parameters
        self.min_anomaly_len = min_anomaly_len
        self.max_anomaly_len = max_anomaly_len
        self.window_slide_step = window_slide_step
        self.window_size_step = window_size_step
        self.assume_independent = assume_independent
        self.which_test_metric = which_test_metric
        self.num_steps = len(x)

        # Validate the length bounds. The corrected values stored on self are the ones
        # used everywhere below, so a bad argument can no longer leak into the interval
        # generation.
        if self.min_anomaly_len < 1:
            warnings.warn(
                "min_anomaly_len must be at least 1. Setting min_anomaly_len to 1."
            )
            self.min_anomaly_len = 1
        if self.max_anomaly_len < self.min_anomaly_len:
            warnings.warn(
                "max_anomaly_len must be at least min_anomaly_len. Setting max_anomaly_len to 10 x min_anomaly_len."
            )
            self.max_anomaly_len = 10 * self.min_anomaly_len

        # range() rejects a zero step and yields nothing for a negative one, so correct
        # the steps in the same warn-and-fix style as the length bounds.
        if self.window_slide_step < 1:
            warnings.warn(
                "window_slide_step must be at least 1. Setting window_slide_step to 1."
            )
            self.window_slide_step = 1
        if self.window_size_step < 1:
            warnings.warn(
                "window_size_step must be at least 1. Setting window_size_step to 1."
            )
            self.window_size_step = 1

        # Candidate intervals as half-open index pairs (start, end), used as [start:end].
        # NOTE(review): both ranges are end-exclusive, so interval lengths never reach
        # max_anomaly_len, no interval ends at the last sample, and
        # min_anomaly_len == max_anomaly_len yields no intervals at all. Kept as in the
        # original implementation; confirm whether inclusive bounds were intended.
        self.intervals = [
            (start, end)
            for start in range(
                0, self.num_steps - self.min_anomaly_len, self.window_slide_step
            )
            for end in range(
                start + self.min_anomaly_len,
                min(start + self.max_anomaly_len, self.num_steps),
                self.window_size_step,
            )
        ]

        # Per-interval results, aligned index-by-index with self.intervals
        self.metrics = []
        self.pos_or_neg_intervals = []

    @staticmethod
    def _metric_to_scalar(metric):
        """Return a single-element tensor metric as a Python scalar; pass other values through."""
        if isinstance(metric, torch.Tensor):
            return metric.detach().cpu().item()
        return metric

    def _plot_interval(self, model, likelihood, start, end, interval_metric):
        """Plot the GP prediction over the full series and highlight the held-out interval."""
        model.eval()
        likelihood.eval()
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            full_pred = likelihood(model(self.x_tensor))
        pred_mean = full_pred.mean.cpu().numpy()
        one_stdev = full_pred.stddev.cpu().numpy()

        plt.figure(figsize=(8, 5))
        plt.title(f"p-value: {interval_metric:.0e}")
        plt.fill_between(
            self.x, pred_mean - one_stdev, pred_mean + one_stdev, alpha=0.5
        )
        plt.scatter(self.x, self.y, c="black", s=3, alpha=0.5, label="Observed")
        plt.scatter(
            self.x[start:end],
            self.y[start:end],
            c="red",
            marker="x",
            s=10,
            alpha=0.8,
            label="Held Out Interval",
        )
        plt.plot(
            self.x,
            pred_mean,
            lw=1,
            alpha=0.7,
            label=r"Predicted $\pm$ 1 $\sigma$",
        )
        plt.xlabel("Time")
        plt.ylabel("Flux")
        plt.xlim(min(self.x), max(self.x))
        plt.legend()
        plt.show()

    def search_for_anomaly(
        self,
        filename="",
        refit=False,
        neg_anomaly_only=False,
        pos_anomaly_only=False,
        dynamic_programming=False,
        threshold=1e-3,
        num_intervals_to_flag=None,
        silent=True,
        plot=False,
        **kwargs,
    ):
        """
        Perform the exhaustive search for anomalies in the time series data.

        Every candidate interval in self.intervals is held out, a predictive distribution
        for it is obtained from the remaining data, and a test metric is computed. The
        per-interval results are stored in self.metrics and self.pos_or_neg_intervals
        (both reset at the start of every call), the interval with the lowest metric in
        self.best_interval / self.min_metric, and the elapsed time in self.runtime.

        For every metric, lower values are treated as more anomalous.
        NOTE(review): confirm this orientation is intended for metrics such as
        'mahalanobis', 'nlpd', 'msll' and 'rmse', where larger values usually indicate a
        worse fit.

        Args:
            filename (str): If provided, saves "start,end,metric" rows to this file (default: "")
            refit (bool): If true, refit the GP for each interval. If false, use the same
                GP hyperparameters for all intervals (faster but less accurate) (default: False)
            neg_anomaly_only (bool): Only flag intervals whose observations lie on average
                at or below the prediction (i.e., dips) (default: False)
            pos_anomaly_only (bool): Only flag intervals whose observations lie on average
                above the prediction (i.e., flares) (default: False). If both *_only flags
                are set, a warning is issued and both kinds are flagged.
            dynamic_programming (bool): If true, reuse the precision matrix of the full GP
                fit for every interval. Forces refit=False and a 'pval' or 'mahalanobis'
                metric (default: False)
            threshold (float or None): Flag an interval when its metric is below this
                value (default: 1e-3). Pass None to use num_intervals_to_flag instead.
            num_intervals_to_flag (int or None): If not None (and threshold is None), flag
                the num_intervals_to_flag intervals with the lowest metric (default: None)
            silent (bool): If true, suppresses print statements during the search (default: True)
            plot (bool): If true, plots the GP prediction for each candidate interval (default: False)

        Keyword Args:
            kwargs: Additional keyword arguments passed to the GP training function, such
                as training_iterations, lr, early_stopping, etc. min_anomaly_len and
                max_anomaly_len are removed before forwarding.
        """
        start_time = time.time()

        # Only one flagging strategy can be active; threshold wins when both are given.
        if threshold is not None and num_intervals_to_flag is not None:
            warnings.warn(
                "Both threshold and num_intervals_to_flag are provided. Only one can be used. Setting num_intervals_to_flag to None and using threshold for flagging anomalous intervals."
            )
            num_intervals_to_flag = None

        # Dynamic programming reuses one global fit, so refitting is incompatible with it.
        if dynamic_programming and refit:
            warnings.warn(
                "Dynamic programming only works if refit is False. Setting refit to False."
            )
            refit = False

        # The precision-based posterior has no GP model for the pointwise metrics.
        if dynamic_programming and self.which_test_metric not in (
            "pval",
            "mahalanobis",
        ):
            warnings.warn(
                "Dynamic programming only works with pval or mahalanobis metrics. Setting self.which_test_metric to 'pval'."
            )
            self.which_test_metric = "pval"

        # The two sign filters are mutually exclusive; requesting both means no filter.
        if neg_anomaly_only and pos_anomaly_only:
            warnings.warn(
                "Both neg_anomaly_only and pos_anomaly_only are True. Flagging both positive and negative anomalies."
            )
            neg_anomaly_only = False
            pos_anomaly_only = False

        # Reset per-run state so repeated calls keep self.metrics aligned with self.intervals.
        self.min_metric = np.inf
        self.best_interval = None
        self.metrics = []
        self.pos_or_neg_intervals = []

        # Write the header of the results file if a filename is provided
        save_to_txt = filename != ""
        if save_to_txt:
            with open(filename, "w") as f:
                f.write("start,end,metric\n")

        # Initial kernel, mean, and likelihood (used as the starting point when refitting)
        init_kernel = self.build_kernel()
        init_mean = self.build_mean()
        init_likelihood = self.build_likelihood()

        # These are search parameters, not training parameters, so train_gp must not see them
        kwargs.pop("min_anomaly_len", None)
        kwargs.pop("max_anomaly_len", None)

        model = None
        if not refit:
            # Fit the GP once on the full data and keep the learned modules
            model, likelihood, _, _ = self.build_gp_model(
                x=self.x_tensor, y=self.y_tensor
            )
            model, likelihood, _ = self.train_gp(
                gp_model=model,
                likelihood=likelihood,
                x=self.x_tensor,
                y=self.y_tensor,
                device=self.device,
                **kwargs,
            )
            kernel = model.covar_module
            mean = model.mean_module

            if dynamic_programming:
                # Precompute and cache the precision matrix for the full dataset
                mu_full, J_full = precompute_precision(
                    full_x=self.x_tensor,
                    mean_module=mean,
                    kernel_module=kernel,
                    noise_variance=model.likelihood.noise.item(),
                    dtype=torch.float64,  # or float32
                    device=self.device,
                )

            # Only kernel, mean and likelihood are needed from here on
            model = None

        # Iterate over each possible anomaly interval
        for start, end in self.intervals:
            f_pred = None

            # Hold out [start:end) as the test interval; train on everything else
            mask_train = np.ones(self.num_steps, dtype=bool)
            mask_train[start:end] = False
            mask_test = ~mask_train

            x_train = torch.tensor(self.x[mask_train], dtype=torch.float32).to(
                self.device
            )
            y_train = torch.tensor(self.y[mask_train], dtype=torch.float32).to(
                self.device
            )
            x_test = torch.tensor(self.x[mask_test], dtype=torch.float32).to(
                self.device
            )
            y_test = torch.tensor(self.y[mask_test], dtype=torch.float32).to(
                self.device
            )

            if dynamic_programming:
                # Use the precomputed precision to get the posterior on the test interval
                y_pred = interval_posterior_from_precision(
                    mu=mu_full,
                    J=J_full,
                    full_y=self.y_tensor,
                    mask_train=mask_train,
                    mask_test=mask_test,
                    dtype=J_full.dtype,
                )
            else:
                if refit:
                    # The init_* modules are shared across iterations, so each fit
                    # starts from the hyperparameters learned on the previous interval.
                    model, likelihood, _, _ = self.build_gp_model(
                        x_train,
                        y_train,
                        kernel=init_kernel,
                        mean=init_mean,
                        likelihood=init_likelihood,
                        device=self.device,
                    )
                    model, likelihood, _ = self.train_gp(
                        gp_model=model,
                        likelihood=likelihood,
                        x=x_train,
                        y=y_train,
                        device=self.device,
                        **kwargs,
                    )
                else:
                    # Condition a new GP on the training data using the hyperparameters
                    # optimized on the full series
                    model, likelihood, _, _ = self.build_gp_model(
                        x_train,
                        y_train,
                        kernel=kernel,
                        mean=mean,
                        likelihood=likelihood,
                        device=self.device,
                    )

                # Predict the held-out interval
                model.eval()
                likelihood.eval()
                with torch.no_grad(), gpytorch.settings.fast_pred_var():
                    f_pred = model(x_test)
                    y_pred = likelihood(f_pred)

            # -1: observations are on average at or below the prediction (dip); +1: above
            mean_residual = np.mean(
                y_test.cpu().numpy() - y_pred.mean.cpu().numpy()
            )
            self.pos_or_neg_intervals.append(-1 if mean_residual <= 0 else 1)

            if self.which_test_metric in ("pval", "mahalanobis"):
                maha_dist, p_value = compute_interval_pvalue(y_test, y_pred)
                if self.which_test_metric == "pval":
                    interval_metric = p_value
                else:
                    interval_metric = maha_dist
            else:
                # Pointwise metrics: evaluate each held-out point separately and average
                metric_sum = 0
                for x_point, y_point in zip(x_test, y_test):
                    x_curr = x_point.unsqueeze(0)
                    y_curr = y_point.unsqueeze(0)
                    with torch.no_grad(), gpytorch.settings.fast_pred_var():
                        f_point = model(x_curr)
                        y_point_pred = likelihood(f_point)

                    if self.which_test_metric == "nlpd":
                        metric = gpytorch.metrics.negative_log_predictive_density(
                            y_point_pred, y_curr
                        )
                    elif self.which_test_metric == "msll":
                        metric = gpytorch.metrics.mean_standardized_log_loss(
                            y_point_pred, y_curr
                        )
                    elif self.which_test_metric == "rmse":
                        point_mean = y_point_pred.mean.cpu().numpy()
                        metric = np.sqrt(
                            np.mean((point_mean - y_curr.cpu().numpy()) ** 2)
                        )
                    elif self.which_test_metric == "mll":
                        mll_func = ExactMarginalLogLikelihood(likelihood, model)
                        metric = mll_func(f_point, y_curr)
                    else:  # Default to log-likelihood
                        metric = y_point_pred.log_prob(y_curr)
                    metric_sum += metric

                interval_metric = metric_sum / (end - start)

            # Save the metric for the interval
            self.metrics.append(interval_metric)

            # Track the interval with the lowest metric
            if interval_metric < self.min_metric:
                self.min_metric = interval_metric
                self.best_interval = (start, end)

            if not silent:
                print(
                    f"Anomaly interval: {start}-{end}, metric {self.which_test_metric} over the interval: {interval_metric}, pos or neg: {self.pos_or_neg_intervals[-1]}"
                )

            if plot:
                if dynamic_programming:
                    # The precision-based path has no GP model; build one on the
                    # training data purely to draw the prediction.
                    model, likelihood, _, _ = self.build_gp_model(
                        x_train,
                        y_train,
                        kernel=kernel,
                        mean=mean,
                        likelihood=likelihood,
                        device=self.device,
                    )
                self._plot_interval(model, likelihood, start, end, interval_metric)

            # Append after every interval so partial results survive an interrupted run
            if save_to_txt:
                with open(filename, "a") as f:
                    f.write(f"{start},{end},{interval_metric}\n")

            # Drop per-interval references to free memory before the next iteration
            model = None
            f_pred = None
            del x_train, y_train, x_test, y_test
            del y_pred, interval_metric
            del mask_train, mask_test
            gc.collect()
            torch.cuda.empty_cache()

        # Scalar view of the metrics so tensors and plain numbers are handled alike
        metric_values = [self._metric_to_scalar(m) for m in self.metrics]

        def sign_allowed(idx):
            # Apply the optional dip-only / flare-only filter
            sign = self.pos_or_neg_intervals[idx]
            if neg_anomaly_only:
                return sign == -1
            if pos_anomaly_only:
                return sign == 1
            return True

        # Update flagged_anomalous, anomalous_signal and num_detected_anomalies
        if threshold is not None:
            # Flag intervals whose metric is below the threshold
            for idx, value in enumerate(metric_values):
                if value < threshold and sign_allowed(idx):
                    start, end = self.intervals[idx]
                    self.flagged_anomalous[start:end] = True
                    self.anomalous_signal[start:end] = value
                    self.num_detected_anomalies += 1
        elif num_intervals_to_flag is not None:
            # Flag the num_intervals_to_flag eligible intervals with the lowest metric
            ranked = [idx for idx in np.argsort(metric_values) if sign_allowed(idx)]
            # Clamp so a negative count flags nothing instead of slicing from the end
            for idx in ranked[: max(num_intervals_to_flag, 0)]:
                start, end = self.intervals[idx]
                self.flagged_anomalous[start:end] = True
                self.anomalous_signal[start:end] = metric_values[idx]
                self.num_detected_anomalies += 1

        self.runtime = time.time() - start_time