Skip to content

Commit

Permalink
Resume from fix (#76)
Browse files Browse the repository at this point in the history
Co-authored-by: Safoora Yousefi <[email protected]>
  • Loading branch information
safooray and Safoora Yousefi authored Jan 17, 2025
1 parent 74da19b commit e855880
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions eureka_ml_insights/core/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,21 @@

from .pipeline import Component
from .reserved_names import INFERENCE_RESERVED_NAMES

MINUTE = 60


class Inference(Component):
def __init__(self, model_config, data_config, output_dir, resume_from=None, new_columns=None, requests_per_minute=None, max_concurrent=1):

def __init__(
self,
model_config,
data_config,
output_dir,
resume_from=None,
new_columns=None,
requests_per_minute=None,
max_concurrent=1,
):
"""
Initialize the Inference component.
args:
Expand Down Expand Up @@ -62,13 +71,13 @@ def fetch_previous_inference_results(self):
# fetch previous results from the provided resume_from file
logging.info(f"Resuming inference from {self.resume_from}")
pre_inf_results_df = DataReader(self.resume_from, format=".jsonl").load_dataset()

# add new columns listed by the user to the previous inference results
if self.new_columns:
for col in self.new_columns:
if col not in pre_inf_results_df.columns:
pre_inf_results_df[col] = None

# validate the resume_from contents
with self.data_loader as loader:
_, sample_model_input = self.data_loader.get_sample_model_input()
Expand All @@ -80,13 +89,17 @@ def fetch_previous_inference_results(self):

# perform a sample inference call to get the model output keys and validate the resume_from contents
sample_response_dict = self.model.generate(*sample_model_input)
if not sample_response_dict["is_valid"]:
raise ValueError(
"Sample inference call for resume_from returned invalid results, please check the model configuration."
)
# check if the inference response dictionary contains the same keys as the resume_from file
eventual_keys = set(sample_response_dict.keys()) | set(sample_data_keys)

# in case of resuming from a file that was generated by an older version of the model,
# we let the discrepancy in the reserved keys slide and later set the missing keys to None
match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES)
match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES)

if set(eventual_keys) != match_keys:
diff = set(eventual_keys) ^ set(match_keys)
raise ValueError(
Expand Down Expand Up @@ -139,6 +152,11 @@ def retrieve_exisiting_result(self, data, pre_inf_results_df):
prev_model_tokens,
prev_model_time,
)
# add remaining pre_inf_results_df columns to the data point
for col in pre_inf_results_df.columns:
if col not in data:
data[col] = prev_results[col].values[0]

return data

def run(self):
Expand Down

0 comments on commit e855880

Please sign in to comment.