Export dacapo #1

Open
wants to merge 3 commits into base: main
88 changes: 88 additions & 0 deletions dcc/model_export/README.md
@@ -0,0 +1,88 @@
# Exporting dacapo models

Use the interactive script to export a model:
```bash
dcc_dacapo
```

The exported folder will contain the following files:

```
<run_name>/
├── model.onnx
├── model.pt
├── model.ts
├── README.md
└── metadata.json
```
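Once exported, the artifacts can be loaded without DaCapo. The snippet below is a minimal sketch, not part of the exporter: it assumes `model.ts` holds the TorchScript trace and `model.onnx` the ONNX graph (`model.pt` is the full pickled module and needs the original model code to load, so it is not used here), that the network takes a single `(batch, channels, *spatial)` tensor, and that `onnxruntime` is installed. The `"<run_name>"` path is a placeholder for an exported folder.

```python
import json

import numpy as np
import onnxruntime as ort
import torch

export_dir = "<run_name>"  # placeholder: path to an exported run folder

# Read the metadata written by the exporter
with open(f"{export_dir}/metadata.json") as f:
    metadata = json.load(f)

# Build a dummy input from the recorded shape: (batch, channels, *spatial)
shape = (1, metadata["in_channels"], *metadata["input_shape"])
dummy = np.random.rand(*shape).astype(np.float32)

# TorchScript: self-contained, no DaCapo import needed
ts_model = torch.jit.load(f"{export_dir}/model.ts")
ts_model.eval()
with torch.no_grad():
    ts_out = ts_model(torch.from_numpy(dummy))

# ONNX: run with onnxruntime and compare against the TorchScript output
session = ort.InferenceSession(f"{export_dir}/model.onnx")
(onnx_out,) = session.run(["output"], {"input": dummy})
print(np.abs(ts_out.numpy() - onnx_out).max())
```

The `"input"`/`"output"` names match the `input_names`/`output_names` passed to `torch.onnx.export` in `dcc/model_export/export_model.py`.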

The `metadata.json` file describes the exported model and has the following structure:

```json
{
"model_name": "model_name",
"model_type": "UNet",
"framework": "Dacapo",
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 2,
"iteration": 1000,
"input_voxel_size": [8, 8, 8],
"output_voxel_size": [8, 8, 8],
"channels_names": ["CT", "PET"],
"input_shape": [96, 96, 96],
"output_shape": [96, 96, 96],
"author": "author",
"description": "description",
"version": "1.0.0"
}
```

| Attribute         | Type                | Description                                           | Example    |
|-------------------|---------------------|-------------------------------------------------------|------------|
| model_name        | Optional[str]       | Name of the exported run                              |            |
| model_type        | Optional[str]       | Architecture class, e.g. UNet or DenseNet121          |            |
| framework         | Optional[str]       | Dacapo or PyTorch                                     |            |
| spatial_dims      | Optional[int]       | Number of spatial dimensions, 2 or 3                  |            |
| in_channels       | Optional[int]       | Number of input channels                              |            |
| out_channels      | Optional[int]       | Number of output channels                             |            |
| iteration         | Optional[int]       | Training iteration of the exported weights            |            |
| input_voxel_size  | Optional[List[int]] | Input voxel size; entered as comma-separated values   | 8,8,8      |
| output_voxel_size | Optional[List[int]] | Output voxel size; entered as comma-separated values  | 8,8,8      |
| channels_names    | Optional[List[str]] | Channel names; entered as comma-separated values      | 'CT, PET'  |
| input_shape       | Optional[List[int]] | Input shape; entered as comma-separated values        | 96,96,96   |
| output_shape      | Optional[List[int]] | Output shape; entered as comma-separated values       | 96,96,96   |
| author            | Optional[str]       | Model author                                          |            |
| description       | Optional[str]       | Free-text description of the model                    |            |
| version           | Optional[str]       | Model version                                         | 1.0.0      |
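These attributes mirror the `ModelMetadata` pydantic model in `dcc/model_export/generate_metadata.py`, so an exported `metadata.json` can be re-validated programmatically. A minimal sketch, assuming the `dcc` package from this PR is installed and `"<run_name>"` is replaced by an exported folder:

```python
import json

from dcc.model_export.generate_metadata import ModelMetadata

# Load an exported metadata.json and validate it against the schema above
with open("<run_name>/metadata.json") as f:
    metadata = ModelMetadata(**json.load(f))

# Typed access after validation
print(metadata.model_type, metadata.iteration, metadata.input_voxel_size)
```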


# Saved models
## jrc_mus_liver
- 8 nm mito:
  - run: v21_mito_attention_finetuned_distances_8nm_mito_jrc_mus-livers_mito_8nm_attention-upsample-unet_default_one_label_1
  - iteration: 345000
- 8 nm peroxisome:
  - run: v22_peroxisome_funetuning_best_v20_1e4_finetuned_distances_8nm_peroxisome_jrc_mus-livers_peroxisome_8nm_attention-upsample-unet_default_one_label_finetuning_0
  - iteration: 45000
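
These runs can be re-exported by calling the helper from `dcc/model_export/dacapo_model.py` directly instead of the interactive `dcc_dacapo` prompt. A rough sketch, assuming a working DaCapo setup (see the `dacapo.yaml` below); note that any missing metadata fields will still be prompted for interactively:

```python
from dcc.model_export.dacapo_model import get_dacapo_infos

# (run_name, iteration) pairs for the saved jrc_mus_liver models listed above
SAVED_RUNS = [
    ("v21_mito_attention_finetuned_distances_8nm_mito_jrc_mus-livers_mito_8nm_attention-upsample-unet_default_one_label_1", 345000),
    ("v22_peroxisome_funetuning_best_v20_1e4_finetuned_distances_8nm_peroxisome_jrc_mus-livers_peroxisome_8nm_attention-upsample-unet_default_one_label_finetuning_0", 45000),
]

for run_name, iteration in SAVED_RUNS:
    # Loads the run config and weights, then writes metadata, .pt, .ts and .onnx files
    get_dacapo_infos(run_name, iteration)
```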


## Note: I had to check out dacapo at the commit from Feb 15, 2024

Pinned commit:
```bash
commit 5371dedd3a008e438b601a227c4166273aab34bf (HEAD)
Author: Marwan Zouinkhi <[email protected]>
Date: Thu Feb 15 10:49:53 2024 -0500

docstrings losses
```

The `dacapo.yaml` configuration used:
```yaml
mongo_db_name: dacapo_cellmap_v3_zouinkhim
runs_base_dir: "/nrs/cellmap/zouinkhim/crop_num_experiment_v2"
mongo_db_host: mongodb://cellmapAdmin:LUwWXkSY8N3AqCcw@cellmap-mongo:27017
type: "mongo"
```
13 changes: 10 additions & 3 deletions dcc/model_export/dacapo_model.py
@@ -2,6 +2,8 @@
from dacapo.store.create_store import create_config_store, create_weights_store
from .generate_metadata import ModelMetadata, export_metadata, get_export_folder
from .export_model import export_torch_model
import os
from funlib.geometry import Coordinate


def export_dacapo_model():
@@ -19,19 +21,24 @@ def get_dacapo_infos(run_name: str, iteration: int):
weights = weights_store.retrieve_weights(run_name, iteration)
run.model.load_state_dict(weights.model)

    # NOTE: the input voxel size is hardcoded to 8 nm isotropic here
    input_scale = Coordinate(8, 8, 8)
    output_scale = run.model.scale(input_scale)

metadata = ModelMetadata(
model_name=run_name,
model_type=run.model.architecture.__name__,
model_type=run.model.architecture.__class__.__name__,
framework="dacapo/torch",
        spatial_dims=len(run.model.input_shape),
in_channels=run.model.num_in_channels,
out_channels=run.model.num_out_channels,
channels_names=run.task.channels,
input_shape=run.model.input_shape,
output_shape=run.model.output_shape,
iteration=iteration,
input_voxel_size=input_scale,
output_voxel_size=output_scale,
)
input_shape = (1, run.model.num_in_channels, *run.model.input_shape)

export_metadata(metadata)
export_torch_model(run.model, input_shape, get_export_folder())
export_torch_model(run.model, input_shape, os.path.join(get_export_folder(), run_name))
8 changes: 8 additions & 0 deletions dcc/model_export/export_model.py
@@ -7,12 +7,19 @@

def export_torch_model(model, input_shape, folder_result):
model.eval()
print(f"Exporting model to {folder_result}")
pt_file = os.path.join(folder_result, "model.pt")
onnx_file = os.path.join(folder_result, "model.onnx")
ts_file = os.path.join(folder_result, "model.ts")

    # Save the full pickled module (loading it back requires the original model code)
    torch.save(model, pt_file)
    print(f"Model saved to {pt_file}")

    # Export to TorchScript by tracing with a dummy input (self-contained artifact)
    dummy_input = torch.rand(input_shape)
    scripted_model = torch.jit.trace(model, dummy_input)
    scripted_model.save(ts_file)
    print(f"Model saved to {ts_file}")

# Export to ONNX
torch.onnx.export(
@@ -25,3 +32,4 @@ def export_torch_model(model, input_shape, folder_result):
input_names=["input"],
output_names=["output"],
)
print(f"Model saved to {onnx_file}")
23 changes: 20 additions & 3 deletions dcc/model_export/generate_metadata.py
@@ -22,6 +22,13 @@ class ModelMetadata(BaseModel):
spatial_dims: Optional[int] = Field(None, description="2 or 3")
in_channels: Optional[int] = None
out_channels: Optional[int] = None
iteration: Optional[int] = None
input_voxel_size: Optional[List[int]] = Field(
None, description="Comma-separated values, e.g., 8,8,8"
)
output_voxel_size: Optional[List[int]] = Field(
None, description="Comma-separated values, e.g., 8,8,8"
)
channels_names: Optional[List[str]] = Field(
None, description="Comma-separated values, e.g., 'CT, PET'"
)
@@ -39,6 +46,7 @@ class ModelMetadata(BaseModel):
def generate_readme(metadata: ModelMetadata):
readme_content = f"""
# {metadata.model_name} Model
iteration: {metadata.iteration}

## Description
{metadata.description}
@@ -62,12 +70,19 @@ def generate_readme(metadata: ModelMetadata):
return readme_content


def export_metadata(metadata: ModelMetadata):
prompt_for_missing_fields(metadata)
def export_metadata(metadata: ModelMetadata, overwrite: bool = False):

export_folder = get_export_folder()
result_folder = os.path.join(export_folder, metadata.model_name)
if os.path.exists(result_folder) and not overwrite:
result = click.confirm(
f"Folder {result_folder} already exists. Do you want to overwrite it?",
)
if not result:
return
metadata = prompt_for_missing_fields(metadata)
os.makedirs(result_folder, exist_ok=True)
output_file = os.path.join(result_folder, f"{metadata.model_name}_metadata.json")
output_file = os.path.join(result_folder, "metadata.json")
with open(output_file, "w") as f:
json.dump(metadata.dict(), f, indent=4)
click.echo(f"Metadata saved to {output_file}")
@@ -97,6 +112,8 @@ def prompt_for_missing_fields(metadata: ModelMetadata):
value = click.prompt(prompt_text, type=str)
setattr(metadata, field_name, value)

return metadata


@click.command()
@click.option("--model_name", prompt="Enter model name", type=str)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -80,7 +80,7 @@ packages = ["dcc"]

[project.scripts]
dacapo = "dcc.cli:cli"
dcc_dacapo = "dcc.model_export.export_model:export_dacapo_model"
dcc_dacapo = "dcc.model_export.dacapo_model:export_dacapo_model"

# https://github.com/charliermarsh/ruff
[tool.ruff]