Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unknown args #77

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions eureka_ml_insights/core/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,21 @@

from .pipeline import Component
from .reserved_names import INFERENCE_RESERVED_NAMES

MINUTE = 60


class Inference(Component):
def __init__(self, model_config, data_config, output_dir, resume_from=None, new_columns=None, requests_per_minute=None, max_concurrent=1):

def __init__(
self,
model_config,
data_config,
output_dir,
resume_from=None,
new_columns=None,
requests_per_minute=None,
max_concurrent=1,
):
"""
Initialize the Inference component.
args:
Expand Down Expand Up @@ -62,13 +71,13 @@ def fetch_previous_inference_results(self):
# fetch previous results from the provided resume_from file
logging.info(f"Resuming inference from {self.resume_from}")
pre_inf_results_df = DataReader(self.resume_from, format=".jsonl").load_dataset()

# add new columns listed by the user to the previous inference results
if self.new_columns:
for col in self.new_columns:
if col not in pre_inf_results_df.columns:
pre_inf_results_df[col] = None

# validate the resume_from contents
with self.data_loader as loader:
_, sample_model_input = self.data_loader.get_sample_model_input()
Expand All @@ -80,13 +89,17 @@ def fetch_previous_inference_results(self):

# perform a sample inference call to get the model output keys and validate the resume_from contents
sample_response_dict = self.model.generate(*sample_model_input)
if not sample_response_dict["is_valid"]:
raise ValueError(
"Sample inference call for resume_from returned invalid results, please check the model configuration."
)
# check if the inference response dictionary contains the same keys as the resume_from file
eventual_keys = set(sample_response_dict.keys()) | set(sample_data_keys)

# in case of resuming from a file that was generated by an older version of the model,
# we let the discrepancy in the reserved keys slide and later set the missing keys to None
match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES)
match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES)

if set(eventual_keys) != match_keys:
diff = set(eventual_keys) ^ set(match_keys)
raise ValueError(
Expand Down Expand Up @@ -139,6 +152,11 @@ def retrieve_exisiting_result(self, data, pre_inf_results_df):
prev_model_tokens,
prev_model_time,
)
# add remaining pre_inf_results_df columns to the data point
for col in pre_inf_results_df.columns:
if col not in data:
data[col] = prev_results[col].values[0]

return data

def run(self):
Expand Down
17 changes: 15 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,23 @@
parser.add_argument(
"--resume_from", type=str, help="The path to the inference_result.jsonl to resume from.", default=None
)
args = parser.parse_args()
init_args = {}

# Catch any arguments the parser does not recognise so they can be forwarded
# to the experiment config class instead of aborting argument parsing.
args, unknown_args = parser.parse_known_args()
if unknown_args:
    # Split the leftover tokens into even-positioned (candidate flags) and
    # odd-positioned (candidate values) entries.
    flags, values = unknown_args[::2], unknown_args[1::2]
    # Interpret the leftovers as "--key value" pairs only when every flag has a
    # matching value and every even-positioned token starts with "--".
    # Without the length check, an odd number of tokens (e.g. a lone "--flag")
    # raised IndexError while looking up the missing value.
    if len(flags) == len(values) and all(flag.startswith("--") for flag in flags):
        init_args.update({flag[len("--") :]: value for flag, value in zip(flags, values)})
        logging.info(f"Unknown arguments: {init_args} will be sent to the experiment config class.")
    # Otherwise forward the raw token list as is under the "unknown_args" key.
    else:
        init_args["unknown_args"] = unknown_args
        logging.info(f"Unknown arguments: {unknown_args} will be sent as is to the experiment config class.")

experiment_config_class = args.exp_config
init_args = {}
if args.model_config:
try:
init_args["model_config"] = getattr(model_configs, args.model_config)
Expand Down
Loading