Skip to content

Commit

Permalink
docs + formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Safoora Yousefi committed Oct 26, 2024
1 parent 7eff14a commit b933ec1
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions eureka_ml_insights/core/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@ def __init__(self, model_config, data_config, output_dir, resume_from=None, n_ca
self.model = model_config.class_name(**model_config.init_args)
self.data_loader = data_config.class_name(**data_config.init_args)
self.writer = JsonLinesWriter(os.path.join(output_dir, "inference_result.jsonl"))

self.resume_from = resume_from
if resume_from and not os.path.exists(resume_from):
raise FileNotFoundError(f"File {resume_from} not found.")

# rate limiting parameters
self.n_calls_per_min = n_calls_per_min
self.max_concurrent = max_concurrent
self.call_times = deque()
self.period = MINUTE

# parallel inference parameters
self.max_concurrent = max_concurrent

@classmethod
def from_config(cls, config):
return cls(
Expand Down Expand Up @@ -79,13 +83,16 @@ def fetch_previous_inference_results(self):
return pre_inf_results_df, last_uid

def validate_response_dict(self, response_dict):
    """Validate that the response dictionary contains the required fields.

    "model_output" and "is_valid" are mandatory fields to be returned by any model.

    Args:
        response_dict: the response returned by the model; only key
            membership is checked here, not the values.

    Raises:
        ValueError: if "model_output" or "is_valid" is missing.
    """
    if "model_output" not in response_dict or "is_valid" not in response_dict:
        raise ValueError("Response dictionary must contain 'model_output' and 'is_valid' keys.")

def retrieve_exisiting_result(self, data, pre_inf_results_df):
# if resume_from file is provided and valid inference results
# for the current data point are present in it, use them.
"""Finds the previous result for the given data point from the pre_inf_results_df and returns it if it is valid
data: dict, the data point to run inference on
pre_inf_results_df: pd.DataFrame, the previous inference results
"""
prev_results = pre_inf_results_df[pre_inf_results_df.uid == data["uid"]]
prev_result_is_valid = bool(prev_results["is_valid"].values[0])
prev_model_output = prev_results["model_output"].values[0]
Expand All @@ -101,6 +108,7 @@ def run(self):
self._run()

def _run(self):
"""sequential inference"""
if self.resume_from:
pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
with self.data_loader as loader:
Expand Down Expand Up @@ -130,10 +138,16 @@ def _run(self):
writer.write(data)

async def run_in_excutor(self, model_inputs, executor):
    """Run model.generate in an executor and await its result.

    Args:
        model_inputs (tuple): inputs to the model.generate function; they
            are unpacked as positional arguments.
        executor (ThreadPoolExecutor): ThreadPoolExecutor instance that
            runs the call.

    Returns:
        Whatever self.model.generate returns for the given inputs.
    """
    # A coroutine always executes inside a running loop, so ask for that loop
    # directly; get_running_loop() is the documented call for this context.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, self.model.generate, *model_inputs)

async def _run_par(self):
"""parallel inference"""
concurrent_inputs = []
concurrent_metadata = []
if self.resume_from:
Expand Down Expand Up @@ -163,6 +177,13 @@ async def _run_par(self):
await self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)

async def run_batch(self, concurrent_inputs, concurrent_metadata, writer, executor):
"""Run a batch of inferences concurrently using ThreadPoolExecutor.
args:
concurrent_inputs (list): list of inputs to the model.generate function.
concurrent_metadata (list): list of metadata corresponding to the inputs.
writer (JsonLinesWriter): JsonLinesWriter instance to write the results.
executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
"""
tasks = [asyncio.create_task(self.run_in_excutor(input_data, executor)) for input_data in concurrent_inputs]
results = await asyncio.gather(*tasks)
for i in range(len(concurrent_inputs)):
Expand Down

0 comments on commit b933ec1

Please sign in to comment.