-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathspatial_reasoning.py
141 lines (123 loc) · 5.66 KB
/
spatial_reasoning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
from eureka_ml_insights.configs.experiment_config import ExperimentConfig
from eureka_ml_insights.core import EvalReporting, Inference, PromptProcessing
from eureka_ml_insights.data_utils import (
AddColumnAndData,
ASTEvalTransform,
HFDataReader,
MMDataLoader,
ColumnRename,
DataReader,
PrependStringTransform,
SequenceTransform,
)
from eureka_ml_insights.data_utils.spatial_utils import (
LowerCaseNoPunctuationConvertNumbers,
)
from eureka_ml_insights.metrics import (
CountAggregator,
SpatialAndLayoutReasoningMetric,
)
from eureka_ml_insights.configs import (
AggregatorConfig,
DataSetConfig,
EvalReportingConfig,
InferenceConfig,
MetricConfig,
PipelineConfig,
PromptProcessingConfig,
)
from .common import LOCAL_DATA_PIPELINE
# NOTE: this triple-quoted string comes after the import statements, so Python
# does not bind it as the module's __doc__ (a docstring must be the first
# statement). It is kept as an explanatory note for readers of this file.
"""This file contains example user defined configuration classes for the spatial reasoning task.
In order to define a new configuration, a new class must be created that directly or indirectly
inherits from ExperimentConfig and the configure_pipeline method should be implemented.
You can inherit from one of the existing user defined classes below and override the necessary
attributes to reduce the amount of code you need to write.
The user defined configuration classes are used to define your desired *pipeline* that can include
any number of *component*s. Find *component* options in the core module.
Pass the name of the class to the main.py script to run the pipeline.
"""
class SPATIAL_REASONING_PAIRS_PIPELINE(ExperimentConfig):
    """ExperimentConfig pipeline for the spatial reasoning dataset, "pairs" condition.

    No model_config is set by default; it must be supplied via the command line.
    """

    def configure_pipeline(self, model_config, resume_from=None):
        """Assemble the prompt-processing -> inference -> eval-reporting pipeline.

        Args:
            model_config: Model configuration handed to the inference component.
            resume_from: Forwarded unchanged as ``resume_from`` to the InferenceConfig.

        Returns:
            A PipelineConfig over the three component configs, logging to ``self.log_dir``.

        Note:
            SPATIAL_REASONING_SINGLE_PIPELINE edits objects created here after this
            method returns. It rewrites ``init_args["tasks"]`` of the data-processing
            reader and the ``data`` of transform index 0 (the AddColumnAndData step)
            of the eval-reporting reader. Keep that step first in the transform list.
        """
        # Stage 1: read the pairs task from the Hugging Face dataset and preprocess prompts.
        hf_reader_args = {
            "path": "microsoft/IMAGE_UNDERSTANDING",
            "split": "val",
            "tasks": "spatial_reasoning_lrtb_pairs",
        }
        self.data_processing_comp = PromptProcessingConfig(
            component_type=PromptProcessing,
            data_reader_config=DataSetConfig(HFDataReader, hf_reader_args),
            output_dir=os.path.join(self.log_dir, "data_processing_output"),
        )

        # Stage 2: run the model over the processed data written by stage 1.
        processed_data_path = os.path.join(self.data_processing_comp.output_dir, "transformed_data.jsonl")
        self.inference_comp = InferenceConfig(
            component_type=Inference,
            model_config=model_config,
            data_loader_config=DataSetConfig(MMDataLoader, {"path": processed_data_path}),
            output_dir=os.path.join(self.log_dir, "inference_result"),
            resume_from=resume_from,
        )

        # Stage 3: score the inference output.
        # The candidate answers are added as a string literal and then passed through
        # ASTEvalTransform (presumably parsed into a Python list). They are then
        # normalized together with ground_truth and model_output.
        eval_transform = SequenceTransform(
            [
                AddColumnAndData("target_options", "['left', 'right', 'above', 'below']"),
                ASTEvalTransform(columns=["target_options"]),
                LowerCaseNoPunctuationConvertNumbers(
                    columns=["ground_truth", "model_output", "target_options"]
                ),
            ]
        )
        result_column = "SpatialAndLayoutReasoningMetric_result"
        self.evalreporting_comp = EvalReportingConfig(
            component_type=EvalReporting,
            data_reader_config=DataSetConfig(
                DataReader,
                {
                    "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"),
                    "format": ".jsonl",
                    "transform": eval_transform,
                },
            ),
            metric_config=MetricConfig(SpatialAndLayoutReasoningMetric),
            aggregator_configs=[
                # Outcome counts over all rows, with normalize=True.
                AggregatorConfig(CountAggregator, {"column_names": [result_column], "normalize": True}),
                # Outcome counts grouped by ground_truth (no normalize flag passed).
                AggregatorConfig(
                    CountAggregator,
                    {"column_names": [result_column], "group_by": "ground_truth"},
                ),
            ],
            output_dir=os.path.join(self.log_dir, "eval_report"),
        )

        stages = [self.data_processing_comp, self.inference_comp, self.evalreporting_comp]
        return PipelineConfig(stages, self.log_dir)
class SPATIAL_REASONING_SINGLE_PIPELINE(SPATIAL_REASONING_PAIRS_PIPELINE):
    """Spatial reasoning pipeline for the single-object condition.

    Builds the pairs pipeline and then switches it to the single-object task and its
    answer options (left/right/top/bottom instead of left/right/above/below).
    """

    def configure_pipeline(self, model_config, resume_from=None):
        """Build the parent pipeline, then retarget it at the single-object task.

        The parent's PipelineConfig is returned as-is. The changes below are made in
        place on the component configs stored on ``self``, which are the same objects
        the parent passed to PipelineConfig.
        """
        pipeline = super().configure_pipeline(model_config=model_config, resume_from=resume_from)

        # Point the data reader at the single-object task instead of the pairs task.
        reader_args = self.data_processing_comp.data_reader_config.init_args
        reader_args["tasks"] = "spatial_reasoning_lrtb_single"

        # Step 0 of the parent's eval SequenceTransform is the AddColumnAndData that
        # fills "target_options"; swap in the single-object answer set.
        eval_steps = self.evalreporting_comp.data_reader_config.init_args["transform"].transforms
        eval_steps[0].data = "['left', 'right', 'top', 'bottom']"

        return pipeline
class SPATIAL_REASONING_PAIRS_LOCAL_PIPELINE(LOCAL_DATA_PIPELINE, SPATIAL_REASONING_PAIRS_PIPELINE):
    """Pairs-condition spatial reasoning pipeline that reads data from a local directory.

    The local directory is forwarded as the third positional argument to the next
    ``configure_pipeline`` in the MRO. That is presumably ``LOCAL_DATA_PIPELINE`` (the
    first base), which defines how the path is used.
    """

    # Default local dataset directory. This is a machine-specific path, so override it
    # in a subclass, set it on the instance, or pass ``local_path`` to
    # ``configure_pipeline`` instead of editing this file.
    DEFAULT_LOCAL_PATH = "/home/neel/data/spatial_understanding"

    def configure_pipeline(self, model_config, resume_from=None, local_path=None):
        """Build the pairs pipeline against a local copy of the data.

        Args:
            model_config: Model configuration forwarded to the parent pipeline.
            resume_from: Forwarded unchanged to the parent pipeline.
            local_path: Local data directory. When None (the default), falls back to
                ``self.DEFAULT_LOCAL_PATH``, which preserves the previous behavior.

        Returns:
            Whatever the parent ``configure_pipeline`` returns.
        """
        if local_path is None:
            local_path = self.DEFAULT_LOCAL_PATH
        return super().configure_pipeline(model_config, resume_from, local_path)
class SPATIAL_REASONING_SINGLE_LOCAL_PIPELINE(LOCAL_DATA_PIPELINE, SPATIAL_REASONING_SINGLE_PIPELINE):
    """Single-object spatial reasoning pipeline that reads data from a local directory.

    The local directory is forwarded as the third positional argument to the next
    ``configure_pipeline`` in the MRO. That is presumably ``LOCAL_DATA_PIPELINE`` (the
    first base), which defines how the path is used.
    """

    # Default local dataset directory. This is a machine-specific path, so override it
    # in a subclass, set it on the instance, or pass ``local_path`` to
    # ``configure_pipeline`` instead of editing this file.
    DEFAULT_LOCAL_PATH = "/home/neel/data/spatial_understanding"

    def configure_pipeline(self, model_config, resume_from=None, local_path=None):
        """Build the single-object pipeline against a local copy of the data.

        Args:
            model_config: Model configuration forwarded to the parent pipeline.
            resume_from: Forwarded unchanged to the parent pipeline.
            local_path: Local data directory. When None (the default), falls back to
                ``self.DEFAULT_LOCAL_PATH``, which preserves the previous behavior.

        Returns:
            Whatever the parent ``configure_pipeline`` returns.
        """
        if local_path is None:
            local_path = self.DEFAULT_LOCAL_PATH
        return super().configure_pipeline(model_config, resume_from, local_path)