Skip to content

Commit

Permalink
Merge pull request #2913 from bramstroker/feat/smart-switch-light-domain
Browse files Browse the repository at this point in the history
Support light entities to be used with smart_switch profiles
  • Loading branch information
bramstroker authored Jan 10, 2025
2 parents 5965842 + 361f30a commit ee2085a
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 21 deletions.
10 changes: 5 additions & 5 deletions custom_components/powercalc/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
from .group_include.include import find_entities
from .power_profile.factory import get_power_profile
from .power_profile.library import ModelInfo, ProfileLibrary
from .power_profile.power_profile import DEVICE_TYPE_DOMAIN, DeviceType, PowerProfile, get_entity_device_types
from .power_profile.power_profile import DOMAIN_DEVICE_TYPE_MAPPING, SUPPORTED_DOMAINS, DeviceType, PowerProfile
from .sensors.daily_energy import DEFAULT_DAILY_UPDATE_FREQUENCY
from .sensors.power import PowerSensor
from .strategy.factory import PowerCalculatorStrategyFactory
Expand Down Expand Up @@ -791,7 +791,7 @@ def create_source_entity_selector(
"""Create the entity selector for the source entity."""
if self.is_library_flow:
return selector.EntitySelector(
selector.EntitySelectorConfig(domain=list(DEVICE_TYPE_DOMAIN.values())),
selector.EntitySelectorConfig(domain=list(SUPPORTED_DOMAINS)),
)
return selector.EntitySelector()

Expand Down Expand Up @@ -944,7 +944,7 @@ async def async_step_manufacturer(
async def _create_schema() -> vol.Schema:
"""Create manufacturer schema."""
library = await ProfileLibrary.factory(self.hass)
device_types = get_entity_device_types(self.source_entity.domain, self.source_entity.entity_entry) if self.source_entity else None
device_types = DOMAIN_DEVICE_TYPE_MAPPING.get(self.source_entity.domain, set()) if self.source_entity else None
manufacturers = [
selector.SelectOptionDict(value=manufacturer, label=manufacturer)
for manufacturer in await library.get_manufacturer_listing(device_types)
Expand Down Expand Up @@ -992,7 +992,7 @@ async def _create_schema() -> vol.Schema:
"""Create model schema."""
manufacturer = str(self.sensor_config.get(CONF_MANUFACTURER))
library = await ProfileLibrary.factory(self.hass)
device_types = get_entity_device_types(self.source_entity.domain, self.source_entity.entity_entry) if self.source_entity else None
device_types = DOMAIN_DEVICE_TYPE_MAPPING.get(self.source_entity.domain, set()) if self.source_entity else None
models = [selector.SelectOptionDict(value=model, label=model) for model in await library.get_model_listing(manufacturer, device_types)]
return vol.Schema(
{
Expand Down Expand Up @@ -1680,7 +1680,7 @@ def __init__(self, config_entry: ConfigEntry) -> None:
"""Initialize options flow."""
super().__init__()
if AwesomeVersion(HAVERSION) < "2024.12":
self.config_entry = config_entry
self.config_entry = config_entry # pragma: no cover
self.sensor_config = dict(config_entry.data)
self.sensor_type: SensorType = self.sensor_config.get(CONF_SENSOR_TYPE) or SensorType.VIRTUAL_POWER
self.source_entity_id: str = self.sensor_config.get(CONF_ENTITY_ID) # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions custom_components/powercalc/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .helpers import get_or_create_unique_id
from .power_profile.factory import get_power_profile
from .power_profile.library import ModelInfo, ProfileLibrary
from .power_profile.power_profile import DEVICE_TYPE_DOMAIN, DiscoveryBy, PowerProfile
from .power_profile.power_profile import SUPPORTED_DOMAINS, DiscoveryBy, PowerProfile

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -281,7 +281,7 @@ def _check_already_configured(entity: er.RegistryEntry) -> bool:
LambdaFilter(lambda entity: entity.device_id is None),
LambdaFilter(lambda entity: entity.platform == "mqtt" and "segment" in entity.entity_id),
LambdaFilter(lambda entity: entity.platform == "powercalc"),
NotFilter(DomainFilter(DEVICE_TYPE_DOMAIN.values())),
NotFilter(DomainFilter(SUPPORTED_DOMAINS)),
],
FilterOperator.OR,
)
Expand Down
4 changes: 2 additions & 2 deletions custom_components/powercalc/group_include/include.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
DOMAIN,
)
from custom_components.powercalc.discovery import get_power_profile_by_source_entity
from custom_components.powercalc.power_profile.power_profile import DEVICE_TYPE_DOMAIN
from custom_components.powercalc.power_profile.power_profile import SUPPORTED_DOMAINS
from custom_components.powercalc.sensors.energy import RealEnergySensor
from custom_components.powercalc.sensors.power import RealPowerSensor

Expand Down Expand Up @@ -68,7 +68,7 @@ async def find_entities(
def _build_filter(entity_filter: EntityFilter | None) -> EntityFilter:
base_filter = CompositeFilter(
[
DomainFilter(DEVICE_TYPE_DOMAIN.values()),
DomainFilter(SUPPORTED_DOMAINS),
LambdaFilter(lambda entity: entity.platform != "utility_meter"),
],
)
Expand Down
27 changes: 15 additions & 12 deletions custom_components/powercalc/power_profile/power_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, NamedTuple, Protocol, cast
Expand Down Expand Up @@ -68,32 +69,34 @@ class CustomField:
description: str | None = None


DEVICE_TYPE_DOMAIN = {
DEVICE_TYPE_DOMAIN: dict[DeviceType, str | set[str]] = {
DeviceType.CAMERA: CAMERA_DOMAIN,
DeviceType.COVER: COVER_DOMAIN,
DeviceType.GENERIC_IOT: SENSOR_DOMAIN,
DeviceType.LIGHT: LIGHT_DOMAIN,
DeviceType.POWER_METER: SENSOR_DOMAIN,
DeviceType.SMART_DIMMER: LIGHT_DOMAIN,
DeviceType.SMART_SWITCH: SWITCH_DOMAIN,
DeviceType.SMART_SWITCH: {SWITCH_DOMAIN, LIGHT_DOMAIN},
DeviceType.SMART_SPEAKER: MEDIA_PLAYER_DOMAIN,
DeviceType.NETWORK: BINARY_SENSOR_DOMAIN,
DeviceType.PRINTER: SENSOR_DOMAIN,
DeviceType.VACUUM_ROBOT: VACUUM_DOMAIN,
}

DOMAIN_TO_DEVICE_TYPES = defaultdict(set)
for device_type, domain in DEVICE_TYPE_DOMAIN.items():
DOMAIN_TO_DEVICE_TYPES[domain].add(device_type)
SUPPORTED_DOMAINS: set[str] = {domain for domains in DEVICE_TYPE_DOMAIN.values() for domain in (domains if isinstance(domains, set) else {domains})}


def get_entity_device_types(entity_domain: str, entity_entry: RegistryEntry | None) -> set[DeviceType]:
def _build_domain_device_type_mapping() -> Mapping[str, set[DeviceType]]:
"""Get the device types for a given entity domain."""
device_types = set(DOMAIN_TO_DEVICE_TYPES.get(entity_domain, {}))
# see https://github.com/bramstroker/homeassistant-powercalc/issues/1491
if entity_entry and entity_entry.platform in ["hue", "osramlightify"] and entity_domain == LIGHT_DOMAIN:
device_types.add(DeviceType.SMART_SWITCH)
return device_types
domain_to_device_type: defaultdict[str, set[DeviceType]] = defaultdict(set)
for device_type, domains in DEVICE_TYPE_DOMAIN.items():
domain_set = domains if isinstance(domains, set) else {domains}
for domain in domain_set:
domain_to_device_type[domain].add(device_type)
return domain_to_device_type


DOMAIN_DEVICE_TYPE_MAPPING: Mapping[str, set[DeviceType]] = _build_domain_device_type_mapping()


class PowerProfile:
Expand Down Expand Up @@ -378,7 +381,7 @@ def is_entity_domain_supported(self, entity_entry: RegistryEntry) -> bool:
if self.device_type == DeviceType.PRINTER and entity_entry.unit_of_measurement:
return False

return self.device_type in get_entity_device_types(domain, entity_entry)
return self.device_type in DOMAIN_DEVICE_TYPE_MAPPING[domain]


class SubProfileSelector:
Expand Down
11 changes: 11 additions & 0 deletions tests/power_profile/test_power_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,17 @@ async def test_vacuum_entity_domain_supported(hass: HomeAssistant) -> None:
)


async def test_light_domain_supported_for_smart_switch_device_type(hass: HomeAssistant) -> None:
library = await ProfileLibrary.factory(hass)
power_profile = await library.get_profile(
ModelInfo("dummy", "dummy"),
get_test_profile_dir("smart_switch"),
)
assert power_profile.is_entity_domain_supported(
SourceEntity("light.test", "test", "light"),
)


async def test_discovery_does_not_break_when_unknown_device_type(hass: HomeAssistant) -> None:
library = await ProfileLibrary.factory(hass)
power_profile = await library.get_profile(
Expand Down

0 comments on commit ee2085a

Please sign in to comment.