Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel inf #41

Merged
merged 19 commits into from
Oct 30, 2024
1 change: 1 addition & 0 deletions .github/workflows/eureka-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ jobs:
run: |
pip show eureka_ml_insights
export skip_tests_with_missing_ds=1
export skip_slow_tests=1
pwd
make test
30 changes: 20 additions & 10 deletions eureka_ml_insights/configs/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
from dataclasses import dataclass, field
from typing import Any, Type, TypeVar, List
from typing import Any, List, Type, TypeVar

UtilityClassConfigType = TypeVar("UtilityClassConfigType", bound=Type["UtilityClassConfig"])
ComponentConfigType = TypeVar("ComponentConfigType", bound=Type["ComponentConfig"])
Expand All @@ -10,11 +10,12 @@

@dataclass
class UtilityClassConfig:
""" Base class for all utility class configs
"""Base class for all utility class configs
Args:
class_name (Any): The utility class to be used with this config
init_args (dict): The arguments to be passed to the utility class constructor
"""

class_name: Any = None
init_args: dict = field(default_factory=dict)

Expand Down Expand Up @@ -52,70 +53,78 @@ class AggregatorConfig(UtilityClassConfig):

@dataclass
class ComponentConfig:
""" Base class for all component configs
"""Base class for all component configs
Args:
component_type (Any): The component class to be used with this config
output_dir (str): The directory to save the output files of this component
"""

component_type: Any = None
output_dir: str = None


@dataclass
class DataProcessingConfig(ComponentConfig):
""" Config class for the data processing component
"""Config class for the data processing component
Args:
data_reader_config (UtilityClassConfig): The data reader config to be used with this component
output_data_columns (list): List of columns (subset of input columns) to keep in the transformed data output file
"""

data_reader_config: UtilityClassConfigType = None
output_data_columns: List[str] = None


@dataclass
class PromptProcessingConfig(DataProcessingConfig):
""" Config class for the prompt processing component
"""Config class for the prompt processing component
Args:
prompt_template_path (str): The path to the prompt template jinja file
ignore_failure (bool): Whether to ignore the failures in the prompt processing and move on
"""

prompt_template_path: str = None
ignore_failure: bool = False


@dataclass
class DataJoinConfig(DataProcessingConfig):
""" Config class for the data join component
"""Config class for the data join component
Args:
other_data_reader_config (UtilityClassConfig): The data reader config for the dataset to be joined with the main dataset
pandas_merge_args (dict): Arguments to be passed to pandas merge function
"""

other_data_reader_config: UtilityClassConfigType = None
pandas_merge_args: dict = None


@dataclass
class InferenceConfig(ComponentConfig):
""" Config class for the inference component
"""Config class for the inference component
Args:
data_loader_config (UtilityClassConfig): The data loader config to be used with this component
model_config (UtilityClassConfig): The model config to be used with this component
resume_from (str): Optional. Path to the file where previous inference results are stored
"""

data_loader_config: UtilityClassConfigType = None
model_config: UtilityClassConfigType = None
resume_from: str = None
n_calls_per_min: int = None
max_concurrent: int = 1


@dataclass
class EvalReportingConfig(ComponentConfig):
""" Config class for the evaluation reporting component
"""Config class for the evaluation reporting component
Args:
data_reader_config (UtilityClassConfig): The data reader config to configure the data reader for this component
metric_config (UtilityClassConfig): The metric config
metric_config (UtilityClassConfig): The metric config
aggregator_configs (list): List of aggregator configs
visualizer_configs (list): List of visualizer configs
"""

data_reader_config: UtilityClassConfigType = None
metric_config: UtilityClassConfigType = None
aggregator_configs: List[UtilityClassConfigType] = field(default_factory=list)
Expand All @@ -127,11 +136,12 @@ class EvalReportingConfig(ComponentConfig):

@dataclass
class PipelineConfig:
""" Config class for the pipeline class
"""Config class for the pipeline class
Args:
component_configs (list): List of ComponentConfigs
log_dir (str): The directory to save the logs of the pipeline
"""

component_configs: list[ComponentConfigType] = field(default_factory=list)
log_dir: str = None

Expand Down
9 changes: 7 additions & 2 deletions eureka_ml_insights/core/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import numpy as np

from .pipeline import Component
from .reserved_names import (
INFERENCE_RESERVED_NAMES,
PROMPT_PROC_RESERVED_NAMES,
)


def compute_hash(val: str) -> str:
Expand Down Expand Up @@ -67,7 +71,8 @@ def __init__(
data_reader_config: DataReaderConfig
output_dir: str directory to save the output files of this component.
output_data_columns: Optional[List[str]] list of columns (subset of input columns)
to keep in the transformed data output file.
to keep in the transformed data output file. The columns reserved for the Eureka framework
will automatically be added to the output_data_columns if not provided.
"""
super().__init__(output_dir)
self.data_reader = data_reader_config.class_name(**data_reader_config.init_args)
Expand All @@ -88,7 +93,7 @@ def get_desired_columns(self, df):
self.output_data_columns = list(self.output_data_columns)
# if the data was multiplied, keep the columns that are needed to identify datapoint and replicates
# (just in case the user forgot to specify these columns in output_data_columns)
cols_to_keep = ["data_point_id", "data_repeat_id"]
cols_to_keep = set(INFERENCE_RESERVED_NAMES + PROMPT_PROC_RESERVED_NAMES)
self.output_data_columns.extend([col for col in cols_to_keep if col in df.columns])
self.output_data_columns = list(set(self.output_data_columns))
return df[self.output_data_columns]
Expand Down
179 changes: 155 additions & 24 deletions eureka_ml_insights/core/inference.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,58 @@
import asyncio
import logging
import os
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

from eureka_ml_insights.data_utils.data import DataReader, JsonLinesWriter

from .pipeline import Component

MINUTE = 60


class Inference(Component):
def __init__(self, model_config, data_config, output_dir, resume_from=None):
def __init__(self, model_config, data_config, output_dir, resume_from=None, n_calls_per_min=None, max_concurrent=1):
"""
Initialize the Inference component.
args:
model_config (dict): ModelConfig object.
data_config (dict): DataSetConfig object.
output_dir (str): Directory to save the inference results.
resume_from (str): optional. Path to the file where previous inference results are stored.
n_calls_per_min (int): optional. Number of calls to be made per minute, used for rate limiting. If not provided, rate limiting will not be applied.
nushib marked this conversation as resolved.
Show resolved Hide resolved
max_concurrent (int): optional. Maximum number of concurrent inferences to run. Default is 1.
"""
super().__init__(output_dir)
self.model = model_config.class_name(**model_config.init_args)
self.data_loader = data_config.class_name(**data_config.init_args)
self.writer = JsonLinesWriter(os.path.join(output_dir, "inference_result.jsonl"))

self.resume_from = resume_from
if resume_from and not os.path.exists(resume_from):
raise FileNotFoundError(f"File {resume_from} not found.")

# rate limiting parameters
self.n_calls_per_min = n_calls_per_min
self.call_times = deque()
self.period = MINUTE

# parallel inference parameters
self.max_concurrent = max_concurrent

@classmethod
def from_config(cls, config):
return cls(config.model_config, config.data_loader_config, config.output_dir, config.resume_from)
return cls(
config.model_config,
config.data_loader_config,
config.output_dir,
resume_from=config.resume_from,
n_calls_per_min=config.n_calls_per_min,
max_concurrent=config.max_concurrent,
)

def fetch_previous_inference_results(self):
# fetch previous results from the provided resume_from file
Expand All @@ -37,49 +61,156 @@ def fetch_previous_inference_results(self):

# validate the resume_from contents
with self.data_loader as loader:
sample_data = loader.reader.read()
sample_data_keys = sample_data.keys()
_, sample_model_input = self.data_loader.get_sample_model_input()

# verify that "model_output" and "is_valid" columns are present
if "model_output" not in pre_inf_results_df.columns or "is_valid" not in pre_inf_results_df.columns:
raise ValueError("Columns 'model_output' and 'is_valid' are required in the resume_from file.")

# check if remaining columns match those in current data loader
pre_inf_results_keys = pre_inf_results_df.columns.drop(["model_output", "is_valid"])
if set(sample_data_keys) != set(pre_inf_results_keys):
raise ValueError(
f"Columns in resume_from do not match the columns in the current data loader."
f"Current data loader columns: {sample_data_keys}. "
f"Provided inference results columns: {pre_inf_results_keys}."
# perform a sample inference call to get the model output keys and validate the resume_from contents
sample_response_dict = self.model.generate(*sample_model_input)
# check if the inference response dictionary contains the same keys as the resume_from file
if set(sample_response_dict.keys()) != set(pre_inf_results_df.columns):
logging.warn(
f"Columns in resume_from file do not match the current inference response. "
f"Current inference response keys: {sample_response_dict.keys()}. "
f"Resume_from file columns: {pre_inf_results_df.columns}."
)

# find the last uid that was inferenced
last_uid = pre_inf_results_df["uid"].astype(int).max()
logging.info(f"Last uid inferenced: {last_uid}")
return pre_inf_results_df, last_uid

def validate_response_dict(self, response_dict):
# Validate that the response dictionary contains the required fields
# "model_output" and "is_valid" are mandatory fields to be returned by any model
if "model_output" not in response_dict or "is_valid" not in response_dict:
raise ValueError("Response dictionary must contain 'model_output' and 'is_valid' keys.")

def retrieve_exisiting_result(self, data, pre_inf_results_df):
"""Finds the previous result for the given data point from the pre_inf_results_df and returns it if it is valid
data: dict, data point to be inferenced
pre_inf_results_df: pd.DataFrame, previous inference results
"""
prev_results = pre_inf_results_df[pre_inf_results_df.uid == data["uid"]]
if prev_results.empty:
return None
prev_result_is_valid = bool(prev_results["is_valid"].values[0])
prev_model_output = prev_results["model_output"].values[0]

if prev_result_is_valid:
logging.info(f"Skipping inference for uid: {data['uid']}. Using previous results.")
try:
prev_model_tokens = prev_results["n_output_tokens"].values[0]
except KeyError:
logging.warn(
"Previous results do not contain 'n_output_tokens' column, setting to None for this data point."
)
prev_model_tokens = None
try:
prev_model_time = prev_results["response_time"].values[0]
except KeyError:
logging.warn(
"Previous results do not contain 'response_time' column, setting to None for this data point."
)
prev_model_time = None

data["model_output"], data["is_valid"], data["n_output_tokens"], data["response_time"] = (
prev_model_output,
prev_result_is_valid,
prev_model_tokens,
prev_model_time,
)
return data

def run(self):
if self.max_concurrent > 1:
asyncio.run(self._run_par())
else:
self._run()

def _run(self):
"""sequential inference"""
if self.resume_from:
pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
with self.data_loader as loader:
with self.writer as writer:
for data, model_inputs in tqdm(loader, desc="Inference Progress:"):
# if resume_from file is provided and valid inference results
# for the current data point are present in it, use them.

if self.resume_from and (data["uid"] <= last_uid):
prev_results = pre_inf_results_df[pre_inf_results_df.uid == data["uid"]]
prev_result_is_valid = bool(prev_results["is_valid"].values[0])
prev_model_output = prev_results["model_output"].values[0]
if prev_result_is_valid:
logging.info(f"Skipping inference for uid: {data['uid']}. Using previous results.")
data["model_output"], data["is_valid"] = prev_model_output, prev_result_is_valid
writer.write(data)
prev_result = self.retrieve_exisiting_result(data, pre_inf_results_df)
if prev_result:
writer.write(prev_result)
continue
# generate text from model

# generate text from model (optionally at a limited rate)
if self.n_calls_per_min:
while len(self.call_times) >= self.n_calls_per_min:
# remove the oldest call time if it is older than the rate limit period
if time.time() - self.call_times[0] > self.period:
self.call_times.popleft()
else:
# rate limit is reached, wait for a second
time.sleep(1)
self.call_times.append(time.time())
response_dict = self.model.generate(*model_inputs)
# "model_output" and "is_valid" are mandatory fields by any inference component
if "model_output" not in response_dict or "is_valid" not in response_dict:
raise ValueError("Response dictionary must contain 'model_output' and 'is_valid' keys.")
self.validate_response_dict(response_dict)
# write results
data.update(response_dict)
writer.write(data)

async def run_in_excutor(self, model_inputs, executor):
"""Run model.generate in a ThreadPoolExecutor.
args:
model_inputs (tuple): inputs to the model.generate function.
executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, self.model.generate, *model_inputs)

async def _run_par(self):
"""parallel inference"""
concurrent_inputs = []
concurrent_metadata = []
if self.resume_from:
pre_inf_results_df, last_uid = self.fetch_previous_inference_results()
with self.data_loader as loader:
with self.writer as writer:
for data, model_inputs in tqdm(loader, desc="Inference Progress:"):
if self.resume_from and (data["uid"] <= last_uid):
prev_result = self.retrieve_exisiting_result(data, pre_inf_results_df)
if prev_result:
writer.write(prev_result)
continue

# if batch is ready for concurrent inference
elif len(concurrent_inputs) >= self.max_concurrent:
with ThreadPoolExecutor() as executor:
await self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)
concurrent_inputs = []
concurrent_metadata = []
# add data to batch for concurrent inference
concurrent_inputs.append(model_inputs)
concurrent_metadata.append(data)
# if data loader is exhausted but there are remaining data points that did not form a full batch
if concurrent_inputs:
with ThreadPoolExecutor() as executor:
await self.run_batch(concurrent_inputs, concurrent_metadata, writer, executor)

async def run_batch(self, concurrent_inputs, concurrent_metadata, writer, executor):
"""Run a batch of inferences concurrently using ThreadPoolExecutor.
args:
concurrent_inputs (list): list of inputs to the model.generate function.
concurrent_metadata (list): list of metadata corresponding to the inputs.
writer (JsonLinesWriter): JsonLinesWriter instance to write the results.
executor (ThreadPoolExecutor): ThreadPoolExecutor instance.
"""
tasks = [asyncio.create_task(self.run_in_excutor(input_data, executor)) for input_data in concurrent_inputs]
results = await asyncio.gather(*tasks)
for i in range(len(concurrent_inputs)):
data, response_dict = concurrent_metadata[i], results[i]
self.validate_response_dict(response_dict)
# prepare results for writing
data.update(response_dict)
writer.write(data)
Loading
Loading