Merge pull request #620 from roboflow/camera-focus-block
Adding block to assist in setting camera focus
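A minimal sketch of how the new block might be referenced in a Workflows definition. The step type and the focus_measure output follow the manifest added in this PR; the surrounding spec fields and the step/output names are illustrative:

# Hypothetical workflow definition exercising the camera focus block.
CAMERA_FOCUS_WORKFLOW = {
    "version": "1.0",
    "inputs": [{"type": "WorkflowImage", "name": "image"}],
    "steps": [
        {
            "type": "roboflow_core/camera_focus@v1",  # type declared in the manifest below
            "name": "camera_focus",
            "image": "$inputs.image",
        }
    ],
    "outputs": [
        {
            "type": "JsonField",
            "name": "focus_measure",
            "selector": "$steps.camera_focus.focus_measure",
        }
    ],
}

A higher focus_measure indicates a sharper image, so the score can be polled while adjusting the lens until it peaks.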
PawelPeczek-Roboflow authored Aug 30, 2024
2 parents b8d9d46 + 4cff23b commit 547384a
Showing 16 changed files with 378 additions and 44 deletions.
Empty file.
162 changes: 162 additions & 0 deletions inference/core/workflows/core_steps/classical_cv/camera_focus/v1.py
@@ -0,0 +1,162 @@
from typing import List, Literal, Optional, Tuple, Type, Union

import cv2
import numpy as np
from pydantic import AliasChoices, ConfigDict, Field

from inference.core.workflows.core_steps.visualizations.common.base import (
OUTPUT_IMAGE_KEY,
)
from inference.core.workflows.execution_engine.entities.base import (
OutputDefinition,
WorkflowImageData,
)
from inference.core.workflows.execution_engine.entities.types import (
BATCH_OF_IMAGES_KIND,
FLOAT_KIND,
StepOutputImageSelector,
WorkflowImageSelector,
)
from inference.core.workflows.prototypes.block import (
BlockResult,
WorkflowBlock,
WorkflowBlockManifest,
)

SHORT_DESCRIPTION: str = "Helps focus a camera by providing a focus measure."
LONG_DESCRIPTION: str = """
This block calculates the Brenner function score, a measure of the texture in the image.
An in-focus image has a high Brenner function score and contains texture at a smaller
scale than an out-of-focus image. Conversely, an out-of-focus image has a low Brenner
function score and lacks small-scale texture.
"""


class CameraFocusManifest(WorkflowBlockManifest):
type: Literal["roboflow_core/camera_focus@v1"]
model_config = ConfigDict(
json_schema_extra={
"name": "Camera Focus",
"version": "v1",
"short_description": SHORT_DESCRIPTION,
"long_description": LONG_DESCRIPTION,
"license": "Apache-2.0",
"block_type": "classical_computer_vision",
}
)

image: Union[WorkflowImageSelector, StepOutputImageSelector] = Field(
title="Input Image",
description="The input image for this step.",
examples=["$inputs.image", "$steps.cropping.crops"],
validation_alias=AliasChoices("image", "images"),
)

@classmethod
def describe_outputs(cls) -> List[OutputDefinition]:
return [
OutputDefinition(
name=OUTPUT_IMAGE_KEY,
kind=[
BATCH_OF_IMAGES_KIND,
],
),
OutputDefinition(
name="focus_measure",
kind=[
FLOAT_KIND,
],
),
]

@classmethod
def get_execution_engine_compatibility(cls) -> Optional[str]:
return ">=1.0.0,<2.0.0"


class CameraFocusBlockV1(WorkflowBlock):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@classmethod
def get_manifest(cls) -> Type[CameraFocusManifest]:
return CameraFocusManifest

def run(self, image: WorkflowImageData, *args, **kwargs) -> BlockResult:
# Calculate the Brenner measure
brenner_image, brenner_value = calculate_brenner_measure(image.numpy_image)

output = WorkflowImageData(
parent_metadata=image.parent_metadata,
workflow_root_ancestor_metadata=image.workflow_root_ancestor_metadata,
numpy_image=brenner_image,
)

return {
OUTPUT_IMAGE_KEY: output,
"focus_measure": brenner_value,
}


def calculate_brenner_measure(
input_image: np.ndarray,
text_color: Tuple[int, int, int] = (255, 255, 255),
text_thickness: int = 2,
) -> Tuple[np.ndarray, float]:
"""
Brenner's focus measure.
Parameters
----------
input_image : np.ndarray
The input image in grayscale.
text_color : Tuple[int, int, int], optional
The color of the text displaying the Brenner value, in BGR format. Default is white (255, 255, 255).
text_thickness : int, optional
The thickness of the text displaying the Brenner value. Default is 2.
Returns
-------
Tuple[np.ndarray, float]
The Brenner image and the Brenner value.
"""
# Convert image to grayscale if it has 3 channels
if len(input_image.shape) == 3:
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY)

# Convert image to 16-bit integer format
converted_image = input_image.astype(np.int16)

# Get the dimensions of the image
height, width = converted_image.shape

# Initialize two matrices for horizontal and vertical focus measures
horizontal_diff = np.zeros((height, width))
vertical_diff = np.zeros((height, width))

# Calculate horizontal and vertical focus measures
horizontal_diff[:, : width - 2] = np.clip(
converted_image[:, 2:] - converted_image[:, :-2], 0, None
)
vertical_diff[: height - 2, :] = np.clip(
converted_image[2:, :] - converted_image[:-2, :], 0, None
)

# Calculate final focus measure
focus_measure = np.max((horizontal_diff, vertical_diff), axis=0) ** 2

# Convert focus measure matrix to 8-bit for visualization
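    # (assumes a non-uniform image; if focus_measure.max() were 0, this division would produce NaNs)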
focus_measure_image = ((focus_measure / focus_measure.max()) * 255).astype(np.uint8)

# Display the Brenner value on the top left of the image
cv2.putText(
focus_measure_image,
f"Focus value: {focus_measure.mean():.2f}",
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
1,
text_color,
text_thickness,
)

return focus_measure_image, focus_measure.mean()
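The helper can also be used standalone. A short usage sketch, assuming the inference package is installed; the file names are placeholders:

import cv2

from inference.core.workflows.core_steps.classical_cv.camera_focus.v1 import (
    calculate_brenner_measure,
)

frame = cv2.imread("frame.jpg")  # placeholder path; any BGR image works
visualization, score = calculate_brenner_measure(frame)
print(f"Focus score: {score:.2f}")  # higher means sharper
cv2.imwrite("focus_visualization.png", visualization)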
4 changes: 4 additions & 0 deletions inference/core/workflows/core_steps/loader.py
@@ -2,6 +2,9 @@

from inference.core.cache import cache
from inference.core.env import API_KEY, WORKFLOWS_STEP_EXECUTION_MODE
from inference.core.workflows.core_steps.classical_cv.camera_focus.v1 import (
CameraFocusBlockV1,
)
from inference.core.workflows.core_steps.classical_cv.contours.v1 import (
ImageContoursDetectionBlockV1,
)
@@ -287,6 +290,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
ImageThresholdBlockV1,
ImageContoursDetectionBlockV1,
ClipComparisonBlockV2,
CameraFocusBlockV1,
]


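Registration can be sanity-checked by confirming the loader now exposes the block — a sketch, assuming a local install:

from inference.core.workflows.core_steps.loader import load_blocks

# The new block should appear among the loaded workflow block classes.
assert any(block.__name__ == "CameraFocusBlockV1" for block in load_blocks())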
9 changes: 5 additions & 4 deletions tests/inference/models_predictions_tests/conftest.py
@@ -1,14 +1,13 @@
import json
import os.path
import shutil
import zipfile
from typing import Generator
from typing import Dict, Generator

import cv2
import numpy as np
import pytest
import requests
import json
from typing import Dict

from inference.core.env import MODEL_CACHE_DIR

@@ -29,12 +28,12 @@
)



@pytest.fixture(scope="function")
def sam2_multipolygon_response() -> Dict:
with open(SAM2_MULTI_POLY_RESPONSE_PATH) as f:
return json.load(f)


@pytest.fixture(scope="function")
def example_image() -> np.ndarray:
return cv2.imread(EXAMPLE_IMAGE_PATH)
@@ -197,6 +196,7 @@ def sam2_small_model() -> Generator[str, None, None]:
yield model_id
shutil.rmtree(model_cache_dir)


@pytest.fixture(scope="function")
def sam2_tiny_model() -> Generator[str, None, None]:
model_id = "sam2/hiera_tiny"
@@ -217,6 +217,7 @@ def sam2_small_truck_logits() -> Generator[np.ndarray, None, None]:
def sam2_small_truck_mask_from_cached_logits() -> Generator[np.ndarray, None, None]:
yield np.load(SAM2_TRUCK_MASK_FROM_CACHE)


def fetch_and_place_model_in_cache(
model_id: str,
model_package_url: str,
31 changes: 16 additions & 15 deletions tests/inference/models_predictions_tests/test_sam2.py
@@ -1,27 +1,27 @@
import numpy as np
import pytest
import torch
import json
import requests
from copy import deepcopy
from io import BytesIO
from typing import Dict

import numpy as np
import pytest
import requests
import torch
from PIL import Image
from io import BytesIO
from inference.core.entities.requests.sam2 import Sam2PromptSet
from inference.models.sam2 import SegmentAnything2
from inference.models.sam2.segment_anything2 import (
hash_prompt_set,
maybe_load_low_res_logits_from_cache,

from inference.core.entities.requests.sam2 import Sam2PromptSet, Sam2SegmentationRequest
from inference.core.entities.responses.sam2 import Sam2SegmentationPrediction
from inference.core.workflows.core_steps.common.utils import (
convert_inference_detections_batch_to_sv_detections,
)
from inference.core.workflows.core_steps.models.foundation.segment_anything2.v1 import (
convert_sam2_segmentation_response_to_inference_instances_seg_response,
)
from inference.core.workflows.core_steps.common.utils import (
convert_inference_detections_batch_to_sv_detections,
from inference.models.sam2 import SegmentAnything2
from inference.models.sam2.segment_anything2 import (
hash_prompt_set,
maybe_load_low_res_logits_from_cache,
)
from inference.core.entities.responses.sam2 import Sam2SegmentationPrediction
from inference.core.entities.requests.sam2 import Sam2SegmentationRequest
from typing import Dict


@pytest.mark.slow
@@ -236,6 +236,7 @@ def test_sam2_multi_poly(sam2_tiny_model: str, sam2_multipolygon_response: Dict)
except Exception as e:
raise e


def test_model_clears_cache_properly(sam2_small_model, truck_image):
cache_size = 2
model = SegmentAnything2(
38 changes: 23 additions & 15 deletions tests/inference/unit_tests/core/utils/test_sqlite_wrapper.py
@@ -8,7 +8,9 @@
def test_count_empty_table():
# given
conn = sqlite3.connect(":memory:")
q = SQLiteWrapper(db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn)
q = SQLiteWrapper(
db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn
)

# then
assert q.count(connection=conn) == 0
@@ -18,7 +20,9 @@ def test_count_empty_table():
def test_insert():
# given
conn = sqlite3.connect(":memory:")
q = SQLiteWrapper(db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn)
q = SQLiteWrapper(
db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn
)

# when
q.insert(values={"col1": "lorem"}, connection=conn)
@@ -31,7 +35,9 @@ def test_insert():
def test_insert_incorrect_columns():
# given
conn = sqlite3.connect(":memory:")
q = SQLiteWrapper(db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn)
q = SQLiteWrapper(
db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn
)

with pytest.raises(ValueError):
q.insert(values={"col2": "lorem"}, connection=conn)
@@ -42,25 +48,26 @@ def test_insert_incorrect_columns():
def test_select_no_limit():
# given
conn = sqlite3.connect(":memory:")
q = SQLiteWrapper(db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn)
q = SQLiteWrapper(
db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn
)

# when
q.insert(values={"col1": "lorem"}, connection=conn)
q.insert(values={"col1": "ipsum"}, connection=conn)
values = q.select(connection=conn)

# then
assert values == [
{"id": 1, "col1": "lorem"},
{"id": 2, "col1": "ipsum"}
]
assert values == [{"id": 1, "col1": "lorem"}, {"id": 2, "col1": "ipsum"}]
conn.close()


def test_select_limit():
# given
conn = sqlite3.connect(":memory:")
q = SQLiteWrapper(db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn)
q = SQLiteWrapper(
db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn
)

# when
q.insert(values={"col1": "lorem"}, connection=conn)
@@ -77,26 +84,27 @@ def test_select_limit():
def test_flush_no_limit():
# given
conn = sqlite3.connect(":memory:")
q = SQLiteWrapper(db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn)
q = SQLiteWrapper(
db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn
)

# when
q.insert(values={"col1": "lorem"}, connection=conn)
q.insert(values={"col1": "ipsum"}, connection=conn)
values = q.flush(connection=conn)

# then
assert values == [
{"id": 1, "col1": "lorem"},
{"id": 2, "col1": "ipsum"}
]
assert values == [{"id": 1, "col1": "lorem"}, {"id": 2, "col1": "ipsum"}]
assert q.count(connection=conn) == 0
conn.close()


def test_flush_limit():
# given
conn = sqlite3.connect(":memory:")
q = SQLiteWrapper(db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn)
q = SQLiteWrapper(
db_file_path="", table_name="test", columns={"col1": "TEXT"}, connection=conn
)

# when
q.insert(values={"col1": "lorem"}, connection=conn)