Skip to content

Commit

Permalink
feat(api-key): add API key management and validation
Browse files Browse the repository at this point in the history
Implement API key input in settings, validation before image
generation, and secure storage in .secrets.toml. Update GUI to
include API key setup prompts and error handling.
  • Loading branch information
rtuszik committed Sep 3, 2024
1 parent ea91b60 commit 9ccc139
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 10 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,4 @@ cython_debug/
#.idea/
settings.json
settings.local.toml
.secrets.toml
4 changes: 4 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@
load_dotenv=True,
lowercase_read=False,
)


def get_api_key():
return settings.get("REPLICATE_API_KEY") or settings.get("replicate_api_key")
70 changes: 62 additions & 8 deletions src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

import httpx
import toml
from config import settings
from config import get_api_key, settings
from dynaconf import loaders
from loguru import logger
from nicegui import ui

Expand Down Expand Up @@ -58,7 +59,7 @@ def __init__(self, image_generator):
self.image_generator = image_generator
self.settings = settings
self.user_added_models = {}

self.api_key = get_api_key() or os.environ.get("REPLICATE_API_KEY", "")
self._attributes = [
"prompt",
"flux_model",
Expand Down Expand Up @@ -92,11 +93,11 @@ def __init__(self, image_generator):

def setup_ui(self):
ui.dark_mode().enable()
self.check_api_key()

with ui.grid(columns=2).classes("w-screen h-full gap-4 px-8"):
with ui.card().classes("col-span-2"):
ui.label("Flux LoRA API").classes("text-2xl font-bold mb-4")

with ui.card().classes("col-span-full"):
self.setup_top_panel()
with ui.card().classes("h-[70vh] overflow-auto"):
self.setup_left_panel()

Expand All @@ -108,8 +109,15 @@ def setup_ui(self):

logger.info("UI setup completed")

def setup_top_panel(self):
with ui.card().classes("w-full"):
ui.label("Flux LoRA API").classes("text-2xl font-bold")
ui.button(
icon="settings_suggest", on_click=self.open_settings_popup
).classes("absolute-right")

def setup_left_panel(self):
with ui.row().classes("w-full"):
with ui.row().classes("w-full items-end"):
self.replicate_model_select = (
ui.select(
options=self.model_options,
Expand All @@ -119,10 +127,10 @@ def setup_left_panel(self):
self.update_replicate_model(e.value)
),
)
.classes("w-5/6")
.classes("w-4/6 overflow-hidden")
.tooltip("Select or manage Replicate models")
)
ui.button(icon="settings_suggest").classes("w-1/6").on(
ui.button(icon="settings_suggest").classes("ml-2").on(
"click", self.open_user_model_popup
)

Expand Down Expand Up @@ -308,6 +316,38 @@ def setup_bottom_panel(self):
"w-full bg-blue-500 hover:bg-blue-600 text-white font-bold py-2 px-4 rounded"
)

async def open_settings_popup(self):
with ui.dialog() as dialog, ui.card():
ui.label("Settings").classes("text-2xl font-bold")
api_key_input = ui.input(
label="API Key",
placeholder="Enter Replicate API Key...",
password_toggle_button=True,
value=self.api_key,
).classes("w-full")

async def save_settings():
new_api_key = api_key_input.value
if new_api_key != self.api_key:
self.api_key = new_api_key
await self.save_api_key()
dialog.close()
ui.notify("Settings saved successfully", type="positive")

ui.button("Save Settings", on_click=save_settings).classes("mt-4")
dialog.open()

async def save_api_key(self):
settings.set("REPLICATE_API_KEY", self.api_key)

secrets_dict = {"default": {"REPLICATE_API_KEY": self.api_key}}

loaders.write(".secrets.toml", secrets_dict)

os.environ["REPLICATE_API_KEY"] = self.api_key

self.image_generator.set_api_key(self.api_key)

@ui.refreshable
def model_list(self):
for model in self.user_added_models:
Expand Down Expand Up @@ -410,6 +450,15 @@ async def toggle_custom_dimensions(self, e):
await self.save_settings()
logger.info(f"Custom dimensions toggled: {e.value}")

def check_api_key(self):
if not self.api_key:
ui.notify(
"No Replicate API Key found. Please set it in the settings before generating images.",
type="warning",
close_button="OK",
timeout=10000, # 10 seconds
)

async def reset_to_default(self):
with open("settings.toml", "r") as f:
default_settings = toml.load(f)["default"]
Expand All @@ -430,6 +479,11 @@ async def reset_to_default(self):
logger.info("Parameters reset to default values")

async def start_generation(self):
if not self.api_key:
ui.notify(
"Please set your Replicate API Key in the settings.", type="negative"
)
return
if not self.replicate_model_select.value:
ui.notify(
"Please select a Replicate model before generating images.",
Expand Down
9 changes: 8 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys

from config import get_api_key
from gui import create_gui
from loguru import logger
from nicegui import ui
Expand All @@ -18,6 +19,12 @@
generator = ImageGenerator()


api_key = get_api_key()
if api_key:
generator.set_api_key(api_key)
else:
logger.warning("No Replicate API Key found. Please set it in the settings.")

logger.info("Creating and setting up GUI")


Expand All @@ -29,4 +36,4 @@ async def main_page():

logger.info("Starting NiceGUI server")

ui.run(title="Replicate Flux LoRA", port=8080)
ui.run(title="Replicate Flux LoRA", port=8080, favicon="🚀")
17 changes: 16 additions & 1 deletion src/replicate_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import sys

import replicate
Expand All @@ -23,8 +24,14 @@
class ImageGenerator:
def __init__(self):
self.replicate_model = None
self.api_key = None
logger.info("ImageGenerator initialized")

def set_api_key(self, api_key):
self.api_key = api_key
os.environ["REPLICATE_API_KEY"] = api_key
logger.info("API key set")

def set_model(self, replicate_model):
self.replicate_model = replicate_model
logger.info(f"Model set to: {replicate_model}")
Expand All @@ -37,6 +44,13 @@ def generate_images(self, params):
logger.error(error_message)
raise ImageGenerationError(error_message)

if not self.api_key:
error_message = (
"No API key set. Please set an API key before generating images."
)
logger.error(error_message)
raise ImageGenerationError(error_message)

try:
flux_model = params.pop("flux_model", "dev")

Expand All @@ -47,7 +61,8 @@ def generate_images(self, params):
)
logger.info(f"Using Replicate model: {self.replicate_model}")

output = replicate.run(self.replicate_model, input=params)
client = replicate.Client(api_token=self.api_key)
output = client.run(self.replicate_model, input=params)

logger.success(f"Images generated successfully. Output: {output}")
return output
Expand Down

0 comments on commit 9ccc139

Please sign in to comment.