Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolves: Run model loading on separate thread #19

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.1.1",
version="1.1.2",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand Down
2 changes: 1 addition & 1 deletion src/digest/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class ProgressDialog(QProgressDialog):
"""A pop up window with a progress label that goes from 1 to 100"""

def __init__(self, label: str, num_steps: int, parent=None):
def __init__(self, label: str, num_steps: int = 0, parent=None):
"""
label: the text to be shown in the pop up dialog
num_steps: the total number of events the progress bar will load through
Expand Down
819 changes: 212 additions & 607 deletions src/digest/main.py

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions src/digest/model_class/digest_onnx_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.

# pylint: disable=no-name-in-module
import os
from typing import List, Dict, Optional, Tuple, cast
from PySide6.QtCore import QRunnable, Signal, Slot, QObject
from datetime import datetime
import importlib.metadata
from collections import OrderedDict
Expand Down Expand Up @@ -654,3 +656,44 @@ def save_text_report(self, filepath: str) -> None:
f_p.write("Output Tensor(s) Information:\n")
f_p.write(output_table.get_string())
f_p.write("\n\n")


class WorkerSignals(QObject):
completed = Signal(DigestOnnxModel)


class LoadDigestOnnxModelWorker(QRunnable):

def __init__(
self,
model_file_path: str,
model_name: str,
):
super().__init__()
self.signals = WorkerSignals()
self.tab_name = model_name
self.model_file_path = model_file_path
self.unique_id: Optional[str] = None

@Slot()
def run(self):
try:
model_proto = onnx_utils.load_onnx(
self.model_file_path, load_external_data=False
)
opt_model, _ = onnx_utils.optimize_onnx_model(model_proto)
except FileNotFoundError as e:
print(f"File not found: {e.filename}")

digest_model = DigestOnnxModel(
opt_model,
model_name=self.tab_name,
onnx_filepath=self.model_file_path,
)

self.unique_id = digest_model.unique_id

if not self.tab_name:
self.tab_name = digest_model.model_name

self.signals.completed.emit(digest_model)
25 changes: 25 additions & 0 deletions src/digest/model_class/digest_report_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import csv
import ast
import re
from PySide6.QtCore import QRunnable, Signal, Slot, QObject
from typing import Tuple, Optional, List, Dict, Any, Union
import yaml
from digest.model_class.digest_model import (
Expand Down Expand Up @@ -153,6 +154,30 @@ def save_text_report(self, filepath: str) -> None:
return


class WorkerSignals(QObject):
completed = Signal(DigestReportModel)


class LoadDigestReportModelWorker(QRunnable):

def __init__(
self,
model_file_path: str,
model_name: str,
):
super().__init__()
self.signals = WorkerSignals()
self.tab_name = model_name
self.model_file_path = model_file_path
self.unique_id: Optional[str] = None

@Slot()
def run(self):

digest_model = DigestReportModel(self.model_file_path)
self.signals.completed.emit(digest_model)


def validate_yaml(report_file_path: str) -> bool:
"""Check that the provided yaml file is indeed a Digest Report file."""
expected_keys = [
Expand Down
149 changes: 141 additions & 8 deletions src/digest/modelsummary.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.

import os

# pylint: disable=invalid-name
from typing import Optional, Union
import os
from datetime import datetime
from typing import Optional

# pylint: disable=no-name-in-module
from PySide6.QtWidgets import QWidget
from PySide6.QtWidgets import QWidget, QTableWidgetItem
from PySide6.QtGui import QMovie
from PySide6.QtCore import QSize

Expand All @@ -16,6 +16,7 @@
from digest.freeze_inputs import FreezeInputs
from digest.popup_window import PopupWindow
from digest.qt_utils import apply_dark_style_sheet
from digest.model_class.digest_model import SupportedModelTypes, DigestModel
from digest.model_class.digest_onnx_model import DigestOnnxModel
from digest.model_class.digest_report_model import DigestReportModel

Expand All @@ -25,20 +26,22 @@

class modelSummary(QWidget):

def __init__(
self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None
):
# def __init__(
# self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None
# ):
Comment on lines +29 to +31
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# def __init__(
# self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None
# ):

Delete comments

def __init__(self, digest_model: DigestModel, parent=None):
super().__init__(parent)
self.ui = Ui_modelSummary()
self.ui.setupUi(self)
apply_dark_style_sheet(self)

self.file: Optional[str] = None
self.ui.warningLabel.hide()
self.digest_model = digest_model
self.model_id = digest_model.unique_id
self.model_proto: Optional[ModelProto] = None
model_name: str = digest_model.model_name if digest_model.model_name else ""

self.png_file_path: Optional[str] = None
self.load_gif = QMovie(":/assets/gifs/load.gif")
# We set the size of the GIF to half the original
self.load_gif.setScaledSize(QSize(214, 120))
Expand All @@ -50,13 +53,143 @@ def __init__(
self.freeze_inputs: Optional[FreezeInputs] = None
self.freeze_window: Optional[QWidget] = None

self.model_type: Optional[SupportedModelTypes] = None

if isinstance(digest_model, DigestOnnxModel):
self.model_type = SupportedModelTypes.ONNX
self.model_proto = (
digest_model.model_proto if digest_model.model_proto else ModelProto()
)
self.freeze_inputs = FreezeInputs(self.model_proto, model_name)
self.ui.freezeButton.clicked.connect(self.open_freeze_inputs)
self.freeze_inputs.complete_signal.connect(self.close_freeze_window)
elif isinstance(digest_model, DigestReportModel):
self.model_type = SupportedModelTypes.REPORT

# Hide some of the components
self.ui.similarityCorrelation.hide()
self.ui.similarityCorrelationStatic.hide()

self.file = digest_model.filepath
self.setObjectName(model_name)
self.ui.modelName.setText(model_name)
if self.file:
self.ui.modelFilename.setText(self.file)

self.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y"))

self.ui.parameters.setText(format(digest_model.parameters, ","))

node_type_counts = digest_model.node_type_counts
if len(node_type_counts) < 15:
bar_spacing = 40
else:
bar_spacing = 20
self.ui.opHistogramChart.bar_spacing = bar_spacing
self.ui.opHistogramChart.set_data(node_type_counts)
self.ui.nodes.setText(str(sum(node_type_counts.values())))

# Format flops with commas if available
flops_str = "N/A"
if digest_model.flops is not None:
flops_str = format(digest_model.flops, ",")

# Set up the FLOPs pie chart
pie_chart_labels, pie_chart_data = zip(
*digest_model.node_type_flops.items()
)
self.ui.flopsPieChart.set_data(
"FLOPs Intensity Per Op Type",
pie_chart_labels,
pie_chart_data,
)

# Set up the params pie chart
pie_chart_labels, pie_chart_data = zip(
*digest_model.node_type_parameters.items()
)
self.ui.parametersPieChart.set_data(
"Parameter Intensity Per Op Type",
pie_chart_labels,
pie_chart_data,
)

self.ui.flops.setText(flops_str)

# Inputs Table
self.ui.inputsTable.setRowCount(len(digest_model.model_inputs))

for row_idx, (input_name, input_info) in enumerate(
digest_model.model_inputs.items()
):
self.ui.inputsTable.setItem(row_idx, 0, QTableWidgetItem(input_name))
self.ui.inputsTable.setItem(
row_idx, 1, QTableWidgetItem(str(input_info.shape))
)
self.ui.inputsTable.setItem(
row_idx, 2, QTableWidgetItem(str(input_info.dtype))
)
self.ui.inputsTable.setItem(
row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes))
)

self.ui.inputsTable.resizeColumnsToContents()
self.ui.inputsTable.resizeRowsToContents()

# Outputs Table
self.ui.outputsTable.setRowCount(len(digest_model.model_outputs))
for row_idx, (output_name, output_info) in enumerate(
digest_model.model_outputs.items()
):
self.ui.outputsTable.setItem(row_idx, 0, QTableWidgetItem(output_name))
self.ui.outputsTable.setItem(
row_idx, 1, QTableWidgetItem(str(output_info.shape))
)
self.ui.outputsTable.setItem(
row_idx, 2, QTableWidgetItem(str(output_info.dtype))
)
self.ui.outputsTable.setItem(
row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes))
)

self.ui.outputsTable.resizeColumnsToContents()
self.ui.outputsTable.resizeRowsToContents()

if isinstance(digest_model, DigestOnnxModel):

if digest_model.model_version:
# ModelProto Info
self.ui.modelProtoTable.setItem(
0, 1, QTableWidgetItem(digest_model.model_version)
)

if digest_model.graph_name:
self.ui.modelProtoTable.setItem(
1, 1, QTableWidgetItem(digest_model.graph_name)
)

producer_txt = (
f"{digest_model.producer_name} {digest_model.producer_version}"
)
self.ui.modelProtoTable.setItem(2, 1, QTableWidgetItem(producer_txt))

self.ui.modelProtoTable.setItem(
3, 1, QTableWidgetItem(str(digest_model.ir_version))
)

for domain, version in digest_model.imports.items():
row_idx = self.ui.importsTable.rowCount()
self.ui.importsTable.insertRow(row_idx)
if domain == "" or domain == "ai.onnx":
self.ui.opsetVersion.setText(str(version))
domain = "ai.onnx"
self.ui.importsTable.setItem(row_idx, 0, QTableWidgetItem(domain))
self.ui.importsTable.setItem(row_idx, 1, QTableWidgetItem(str(version)))
row_idx += 1

self.ui.importsTable.resizeColumnsToContents()
self.ui.modelProtoTable.resizeColumnsToContents()
self.setObjectName(model_name)

def open_freeze_inputs(self):
if self.freeze_inputs:
Expand Down
4 changes: 4 additions & 0 deletions src/digest/multi_model_selection_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def set_directory(self, directory: str):

total_num_models = len(onnx_file_list) + len(report_file_list)

if total_num_models == 0:
self.update_message_label("No models found in the selected directory.")
return

serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list)

progress = ProgressDialog("Loading models", total_num_models, self)
Expand Down
Loading
Loading