From e855880d79a6affa51c4b6cef8390b7b0fef5503 Mon Sep 17 00:00:00 2001 From: Safoora Yousefi Date: Fri, 17 Jan 2025 14:37:11 -0800 Subject: [PATCH] Resume from fix (#76) Co-authored-by: Safoora Yousefi --- eureka_ml_insights/core/inference.py | 30 ++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/eureka_ml_insights/core/inference.py b/eureka_ml_insights/core/inference.py index cb61b99..1374c5b 100644 --- a/eureka_ml_insights/core/inference.py +++ b/eureka_ml_insights/core/inference.py @@ -11,12 +11,21 @@ from .pipeline import Component from .reserved_names import INFERENCE_RESERVED_NAMES + MINUTE = 60 class Inference(Component): - def __init__(self, model_config, data_config, output_dir, resume_from=None, new_columns=None, requests_per_minute=None, max_concurrent=1): - + def __init__( + self, + model_config, + data_config, + output_dir, + resume_from=None, + new_columns=None, + requests_per_minute=None, + max_concurrent=1, + ): """ Initialize the Inference component. args: @@ -62,13 +71,13 @@ def fetch_previous_inference_results(self): # fetch previous results from the provided resume_from file logging.info(f"Resuming inference from {self.resume_from}") pre_inf_results_df = DataReader(self.resume_from, format=".jsonl").load_dataset() - + # add new columns listed by the user to the previous inference results if self.new_columns: for col in self.new_columns: if col not in pre_inf_results_df.columns: pre_inf_results_df[col] = None - + # validate the resume_from contents with self.data_loader as loader: _, sample_model_input = self.data_loader.get_sample_model_input() @@ -80,13 +89,17 @@ def fetch_previous_inference_results(self): # perform a sample inference call to get the model output keys and validate the resume_from contents sample_response_dict = self.model.generate(*sample_model_input) + if not sample_response_dict["is_valid"]: + raise ValueError( + "Sample inference call for resume_from returned invalid results, please check the model configuration." + ) # check if the inference response dictionary contains the same keys as the resume_from file eventual_keys = set(sample_response_dict.keys()) | set(sample_data_keys) # in case of resuming from a file that was generated by an older version of the model, # we let the discrepancy in the reserved keys slide and later set the missing keys to None - match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES) - + match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES) + if set(eventual_keys) != match_keys: diff = set(eventual_keys) ^ set(match_keys) raise ValueError( @@ -139,6 +152,11 @@ def retrieve_exisiting_result(self, data, pre_inf_results_df): prev_model_tokens, prev_model_time, ) + # add remaining pre_inf_results_df columns to the data point + for col in pre_inf_results_df.columns: + if col not in data: + data[col] = prev_results[col].values[0] + return data def run(self):