Skip to content

Commit

Permalink
ECS: create_task_set() now creates tasks (#8281)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Nov 3, 2024
1 parent 0aa3edc commit 83383c9
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 14 deletions.
37 changes: 32 additions & 5 deletions moto/ecs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,7 @@ def __init__(
@property
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = self.gen_response_object()
response_object["taskDefinitionArn"] = response_object["arn"]
del response_object["arn"]
del response_object["tags"]
response_object["taskDefinitionArn"] = response_object.pop("arn")

if not response_object["requiresCompatibilities"]:
del response_object["requiresCompatibilities"]
Expand All @@ -266,7 +264,10 @@ def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
if not response_object["memory"]:
del response_object["memory"]

return response_object
return {
"taskDefinition": response_object,
"tags": response_object.get("tags", []),
}

@property
def physical_resource_id(self) -> str:
Expand Down Expand Up @@ -547,6 +548,7 @@ def __init__(
launch_type: Optional[str] = None,
service_registries: Optional[List[Dict[str, Any]]] = None,
platform_version: Optional[str] = None,
propagate_tags: str = "NONE",
):
self.cluster_name = cluster.name
self.cluster_arn = cluster.arn
Expand Down Expand Up @@ -594,6 +596,7 @@ def __init__(
]
else:
self.deployments = []
self.propagate_tags = propagate_tags

@property
def arn(self) -> str:
Expand Down Expand Up @@ -1609,6 +1612,7 @@ def create_service(
launch_type: Optional[str] = None,
service_registries: Optional[List[Dict[str, Any]]] = None,
platform_version: Optional[str] = None,
propagate_tags: str = "NONE",
) -> Service:
cluster = self._get_cluster(cluster_str)

Expand All @@ -1635,6 +1639,7 @@ def create_service(
backend=self,
service_registries=service_registries,
platform_version=platform_version,
propagate_tags=propagate_tags,
)
cluster_service_pair = f"{cluster.name}:{service_name}"
self.services[cluster_service_pair] = service
Expand Down Expand Up @@ -2181,13 +2186,35 @@ def create_task_set(
if not service_obj:
raise ServiceNotFoundException

task_set.task_definition = self.describe_task_definition(task_definition).arn
task_def_obj = self.describe_task_definition(task_definition)
task_set.task_definition = task_def_obj.arn
task_set.service_arn = service_obj.arn
task_set.cluster_arn = cluster_obj.arn

service_obj.task_sets.append(task_set)
# TODO: validate load balancers

if scale:
if scale.get("unit") == "PERCENT":
desired_count = service_obj.desired_count
nr_of_tasks = int(desired_count * (scale["value"] / 100))
all_tags = {}
if service_obj.propagate_tags == "TASK_DEFINITION":
all_tags.update({t["key"]: t["value"] for t in task_def_obj.tags})
if service_obj.propagate_tags == "SERVICE":
all_tags.update({t["key"]: t["value"] for t in service_obj.tags})
all_tags.update({t["key"]: t["value"] for t in (tags or [])})
self.run_task(
cluster_str=cluster_str,
task_definition_str=task_definition,
count=nr_of_tasks,
overrides=None,
started_by=self.account_id,
tags=[{"key": k, "value": v} for k, v in all_tags.items()],
launch_type=launch_type,
networking_configuration=network_configuration,
)

return task_set

def describe_task_sets(
Expand Down
16 changes: 7 additions & 9 deletions moto/ecs/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,18 @@ def register_task_definition(self) -> str:
pid_mode=pid_mode,
ephemeral_storage=ephemeral_storage,
)
return json.dumps({"taskDefinition": task_definition.response_object})
return json.dumps(task_definition.response_object)

def list_task_definitions(self) -> str:
family_prefix = self._get_param("familyPrefix")
task_definition_arns = self.ecs_backend.list_task_definitions(family_prefix)
return json.dumps(
{
"taskDefinitionArns": task_definition_arns
# 'nextToken': str(uuid.uuid4())
}
)
return json.dumps({"taskDefinitionArns": task_definition_arns})

def describe_task_definition(self) -> str:
task_definition_str = self._get_param("taskDefinition")
data = self.ecs_backend.describe_task_definition(task_definition_str)
resp: Dict[str, Any] = {"taskDefinition": data.response_object, "failures": []}
resp: Dict[str, Any] = data.response_object
resp["failures"] = []
if "TAGS" in self._get_param("include", []):
resp["tags"] = self.ecs_backend.list_tags_for_resource(data.arn)
return json.dumps(resp)
Expand All @@ -173,7 +169,7 @@ def deregister_task_definition(self) -> str:
task_definition = self.ecs_backend.deregister_task_definition(
task_definition_str
)
return json.dumps({"taskDefinition": task_definition.response_object})
return json.dumps(task_definition.response_object)

def run_task(self) -> str:
cluster_str = self._get_param("cluster", "default")
Expand Down Expand Up @@ -265,6 +261,7 @@ def create_service(self) -> str:
deployment_controller = self._get_param("deploymentController")
launch_type = self._get_param("launchType")
platform_version = self._get_param("platformVersion")
propagate_tags = self._get_param("propagateTags") or "NONE"
service = self.ecs_backend.create_service(
cluster_str,
service_name,
Expand All @@ -277,6 +274,7 @@ def create_service(self) -> str:
launch_type,
service_registries=service_registries,
platform_version=platform_version,
propagate_tags=propagate_tags,
)
return json.dumps({"service": service.response_object})

Expand Down
68 changes: 68 additions & 0 deletions tests/test_ecs/test_ecs_tasksets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import json

import boto3
import pytest
from botocore.exceptions import ClientError

from moto import mock_aws
from moto.ec2 import utils as ec2_utils
from tests import EXAMPLE_AMI_ID

cluster_name = "test_ecs_cluster"
service_name = "test_ecs_service"
Expand Down Expand Up @@ -376,6 +380,70 @@ def test_create_task_sets_with_tags():
assert {"key": "k3", "value": "v3"} in resp["tags"]


@mock_aws
def test_create_task_set_and_list_tasks():
tags = [{"key": "key-1", "value": "value-1"}, {"key": "key-2", "value": "value-1"}]

ec2 = boto3.resource("ec2", "us-east-1")
ecs = boto3.client("ecs", "us-east-1")

cluster_a = ecs.create_cluster(clusterName="test-cluster-a")["cluster"]

test_instance = ec2.create_instances(
ImageId=EXAMPLE_AMI_ID, MinCount=1, MaxCount=1
)[0]
instance_id_document = json.dumps(
ec2_utils.generate_instance_identity_document(test_instance)
)
ecs.register_container_instance(
cluster="test-cluster-a", instanceIdentityDocument=instance_id_document
)

response = ecs.register_task_definition(
family="test-family",
containerDefinitions=[
{"name": "test-container-def", "image": "foo", "memory": 256}
],
tags=tags,
)

task_def = response["taskDefinition"]

service_response = ecs.create_service(
cluster=cluster_a["clusterArn"],
serviceName="test-service",
taskDefinition=task_def["taskDefinitionArn"],
desiredCount=1,
deploymentController={"type": "EXTERNAL"},
propagateTags="TASK_DEFINITION",
)
service = service_response["service"]

ecs.create_task_set(
service=service["serviceName"],
cluster=cluster_a["clusterName"],
externalId="test-ext-id",
taskDefinition=task_def["taskDefinitionArn"],
scale={"unit": "PERCENT", "value": 100},
)

list_tasks = ecs.list_tasks(
cluster=cluster_a["clusterName"], serviceName=service["serviceName"]
)

task_arns = list_tasks["taskArns"]
assert len(task_arns) == 1

describe_tasks = ecs.describe_tasks(
cluster=cluster_a["clusterName"], tasks=task_arns, include=["TAGS"]
)

tasks = describe_tasks["tasks"]

assert len(tasks) == 1
assert tasks[0]["tags"] == tags


def create_task_def(client):
client.register_task_definition(
family=task_def_name,
Expand Down

0 comments on commit 83383c9

Please sign in to comment.