Skip to content
This repository has been archived by the owner on Jan 22, 2025. It is now read-only.

Add some typehints #252

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions d2go/runner/default_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from detectron2.engine import AMPTrainer, HookBase, hooks, SimpleTrainer
from detectron2.evaluation import (
COCOEvaluator,
DatasetEvaluator,
DatasetEvaluators,
inference_on_dataset,
LVISEvaluator,
Expand Down Expand Up @@ -150,7 +151,7 @@ def register(self, cfg):
pass

@staticmethod
def get_default_cfg():
def get_default_cfg() -> CfgNode:
"""
Override `get_default_cfg` for adding non common config.
"""
Expand Down Expand Up @@ -215,7 +216,7 @@ def register(self, cfg):
patch_d2_meta_arch()

@staticmethod
def get_default_cfg():
def get_default_cfg() -> CfgNode:
_C = super(Detectron2GoRunner, Detectron2GoRunner).get_default_cfg()
return get_default_cfg(_C)

Expand Down Expand Up @@ -520,7 +521,9 @@ def get_data_loader_vis_wrapper() -> Optional[Type[DataLoaderVisWrapper]]:
return DataLoaderVisWrapper

@staticmethod
def get_evaluator(cfg, dataset_name, output_folder):
def get_evaluator(
cfg: CfgNode, dataset_name: str, output_folder: str
) -> DatasetEvaluator:
evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
if evaluator_type in ["coco", "coco_panoptic_seg"]:
# D2 is in the process of reducing the use of cfg.
Expand Down Expand Up @@ -624,7 +627,7 @@ def _add_rcnn_default_config(_C):

class GeneralizedRCNNRunner(Detectron2GoRunner):
@staticmethod
def get_default_cfg():
def get_default_cfg() -> CfgNode:
_C = super(GeneralizedRCNNRunner, GeneralizedRCNNRunner).get_default_cfg()
_add_rcnn_default_config(_C)
return _C
7 changes: 5 additions & 2 deletions d2go/runner/lightning_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from d2go.utils.ema_state import EMAState
from d2go.utils.misc import get_tensorboard_log_dir
from d2go.utils.visualization import VisualizationEvaluator
from detectron2.evaluation import DatasetEvaluator
from detectron2.solver import (
build_lr_scheduler as d2_build_lr_scheduler,
build_optimizer as d2_build_optimizer,
Expand Down Expand Up @@ -364,7 +365,9 @@ def _initialize(cfg: CfgNode):
pass

@staticmethod
def get_evaluator(cfg: CfgNode, dataset_name: str, output_folder: str):
def get_evaluator(
cfg: CfgNode, dataset_name: str, output_folder: str
) -> DatasetEvaluator:
return Detectron2GoRunner.get_evaluator(
cfg=cfg, dataset_name=dataset_name, output_folder=output_folder
)
Expand Down Expand Up @@ -494,5 +497,5 @@ def prepare_for_quant_convert(self) -> pl.LightningModule:

class GeneralizedRCNNTask(DefaultTask):
@classmethod
def get_default_cfg(cls):
def get_default_cfg(cls) -> CfgNode:
return GeneralizedRCNNRunner.get_default_cfg()