From e7089b8a7ba046ed6d31d35653395331fdea513b Mon Sep 17 00:00:00 2001 From: kaj <40004347+KAJdev@users.noreply.github.com> Date: Thu, 20 Jul 2023 11:39:21 -0800 Subject: [PATCH] generation requests return images --- .github/workflows/comfy_windows.yml | 3 + .../stablestudio-ui/src-tauri/src/main.rs | 1 - .../stablestudio-ui/src-tauri/src/server.rs | 5 + packages/stablestudio-ui/src/App/index.tsx | 1 + packages/stablestudio-ui/src/Comfy/index.tsx | 173 ++++++++++++++++- packages/stablestudio-ui/src/Comfy/plugin.ts | 180 ++++++++++-------- .../src/Generation/Image/Create/index.tsx | 122 +----------- .../src/Generation/Image/Images/index.tsx | 2 +- .../src/Generation/Image/Output/State.tsx | 7 +- 9 files changed, 289 insertions(+), 205 deletions(-) diff --git a/.github/workflows/comfy_windows.yml b/.github/workflows/comfy_windows.yml index d654b70..3eed341 100644 --- a/.github/workflows/comfy_windows.yml +++ b/.github/workflows/comfy_windows.yml @@ -6,6 +6,9 @@ on: branches: - tauri +concurrency: + group: comfyui_windows + jobs: repackage_comfyui: permissions: diff --git a/packages/stablestudio-ui/src-tauri/src/main.rs b/packages/stablestudio-ui/src-tauri/src/main.rs index c81ec6e..20da6e0 100644 --- a/packages/stablestudio-ui/src-tauri/src/main.rs +++ b/packages/stablestudio-ui/src-tauri/src/main.rs @@ -2,7 +2,6 @@ #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] use std::collections::HashMap; -use std::fmt::format; use std::fs::File; use std::sync::OnceLock; use tauri::api::process::CommandEvent; diff --git a/packages/stablestudio-ui/src-tauri/src/server.rs b/packages/stablestudio-ui/src-tauri/src/server.rs index 0d81ab1..980884e 100644 --- a/packages/stablestudio-ui/src-tauri/src/server.rs +++ b/packages/stablestudio-ui/src-tauri/src/server.rs @@ -106,6 +106,11 @@ impl Builder { "`ws${window.location.protocol === \"https:\" ? \"s\" : \"\"}://${location.host}/ws${existingSession}`", "`ws://localhost:5000/ws${existingSession}`" ); + + // add some stuff to app.js + if path_name.ends_with("app.js") && !file_contents.ends_with("app.api = api;") { + file_contents = file_contents + "\napp.api = api;"; + } let response = HttpResponse::from_data(file_contents.as_bytes()).with_status_code(200).with_header( Header::from_bytes("Content-Type", mimetype.unwrap().as_str()).unwrap(), diff --git a/packages/stablestudio-ui/src/App/index.tsx b/packages/stablestudio-ui/src/App/index.tsx index badd5af..2f17d24 100644 --- a/packages/stablestudio-ui/src/App/index.tsx +++ b/packages/stablestudio-ui/src/App/index.tsx @@ -222,6 +222,7 @@ export namespace App { setRunning(true); setIsSetup(SetupState.ComfyRunning); + Comfy.registerListeners(); }, [isSetup, print, setRunning, setUnlisteners]); useEffect(() => { diff --git a/packages/stablestudio-ui/src/Comfy/index.tsx b/packages/stablestudio-ui/src/Comfy/index.tsx index 5b2ca6e..b7639d2 100644 --- a/packages/stablestudio-ui/src/Comfy/index.tsx +++ b/packages/stablestudio-ui/src/Comfy/index.tsx @@ -1,7 +1,9 @@ +import * as StableStudio from "@stability/stablestudio-plugin"; import { useLocation } from "react-router-dom"; import { create } from "zustand"; +import { Generation } from "~/Generation"; -export type Comfy = { +export type ComfyApp = { setup: () => void; registerNodes: () => void; loadGraphData: (graph: Graph) => void; @@ -13,6 +15,21 @@ export type Comfy = { refreshComboInNodes: () => Promise; queuePrompt: (number: number, batchCount: number) => Promise; clean: () => void; + api: ComfyAPI; +}; + +export type ComfyAPI = { + addEventListener: (event: string, callback: (detail: any) => void) => void; +}; + +export type Comfy = { app: ComfyApp; api: ComfyAPI }; + +export type ComfyOutput = { + images: { + filename: string; + subfolder: string; + type: string; + }[]; }; export type Graph = { @@ -91,11 +108,11 @@ type State = { }; export namespace Comfy { - export const get = (): Comfy | null => + export const get = (): ComfyApp | null => (( (document.getElementById("comfyui-window") as HTMLIFrameElement) - ?.contentWindow as Window & { app: Comfy } - )?.app as Comfy) ?? null; + ?.contentWindow as Window & { app: ComfyApp } + )?.app as ComfyApp) ?? null; export const use = create((set) => ({ output: [], @@ -115,4 +132,152 @@ export namespace Comfy { unlisteners: [], setUnlisteners: (unlisteners) => set({ unlisteners }), })); + + export const registerListeners = async () => { + let api = get()?.api; + + while (!api) { + await new Promise((resolve) => setTimeout(resolve, 1000)); + api = get()?.api; + } + + api.addEventListener("executed", async ({ detail }) => { + const { output, prompt_id } = detail; + + console.log("executed_in_comfy_domain", detail); + + const newInputs: Record = {}; + const responses: Generation.Images = []; + + const input = Generation.Image.Input.get(prompt_id); + + const images = await Promise.all( + (output as ComfyOutput).images.map(async (image) => { + console.log("image", image); + const resp = await fetch( + `http://localhost:3000/view?filename=${image.filename}&subfolder=${ + image.subfolder || "" + }&type=${image.type}`, + { + cache: "no-cache", + } + ); + + const blob = await resp.blob(); + const url = URL.createObjectURL(blob); + console.log("url", url); + + const output = Generation.Image.Output.get(prompt_id); + + return { + id: ID.create(), + blob, + inputID: output?.inputID ?? "", + createdAt: new Date(), + }; + }) + ); + + for (const image of images) { + const inputID = ID.create(); + const newInput = { + ...Generation.Image.Input.initial(inputID), + ...input, + seed: 0, + id: inputID, + }; + + const cropped = await cropImage(image, newInput); + if (!cropped) continue; + + responses.push(cropped); + newInputs[inputID] = newInput; + } + + Generation.Image.Inputs.set({ + ...Generation.Image.Inputs.get(), + ...newInputs, + }); + responses.forEach(Generation.Image.add); + Generation.Image.Output.received(prompt_id, responses); + }); + + api.addEventListener("execution_start", ({ detail }) => { + const { prompt_id } = detail; + + console.log("execution_start", detail); + + if (prompt_id) { + let input = Generation.Image.Input.get(prompt_id); + if (!input) { + input = Generation.Image.Input.initial(prompt_id); + Generation.Image.Inputs.set((inputs) => ({ + ...inputs, + [prompt_id]: input, + })); + } + const output = Generation.Image.Output.requested( + prompt_id, + {}, + prompt_id + ); + Generation.Image.Output.set(output); + } + }); + + api.addEventListener("execution_error", ({ detail }) => { + console.log("execution_error", detail); + Generation.Image.Output.clear(detail.prompt_id); + }); + + console.log("registered ComfyUI listeners"); + }; +} + +function cropImage( + image: StableStudio.StableDiffusionImage, + input: Generation.Image.Input +) { + return new Promise((resolve) => { + const id = image.id; + const blob = image.blob; + if (!blob || !id) return resolve(); + + // crop image to box size + const croppedCanvas = document.createElement("canvas"); + croppedCanvas.width = input.width; + croppedCanvas.height = input.height; + + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const croppedCtx = croppedCanvas.getContext("2d")!; + + const img = new window.Image(); + img.src = URL.createObjectURL(blob); + img.onload = () => { + croppedCtx.drawImage( + img, + 0, + 0, + input.width, + input.height, + 0, + 0, + input.width, + input.height + ); + + croppedCanvas.toBlob((blob) => { + if (blob) { + const objectURL = URL.createObjectURL(blob); + resolve({ + id, + inputID: input.id, + created: new Date(), + src: objectURL, + finishReason: 0, + }); + } + }); + }; + }); } diff --git a/packages/stablestudio-ui/src/Comfy/plugin.ts b/packages/stablestudio-ui/src/Comfy/plugin.ts index f90a2d5..3d583a1 100644 --- a/packages/stablestudio-ui/src/Comfy/plugin.ts +++ b/packages/stablestudio-ui/src/Comfy/plugin.ts @@ -1,94 +1,112 @@ import * as StableStudio from "@stability/stablestudio-plugin"; -import { Comfy } from "."; +export const createPlugin = StableStudio.createPlugin(({ set, get }) => { + return { + manifest: { + name: "ComfyUI Backend", + }, -export const createPlugin = StableStudio.createPlugin(({ set, get }) => ({ - manifest: { - name: "ComfyUI Backend", - }, + statusStuff: { + indicator: "loading", + text: "Starting", + }, - statusStuff: { - indicator: "loading", - text: "Starting", - }, + // createStableDiffusionImages: async () => { + // const comfy = Comfy.get(); - createStableDiffusionImages: async () => { - const comfy = Comfy.get(); + // if (!comfy) { + // console.log(document.getElementById("comfyui-window")); + // throw new Error("ComfyUI is not loaded"); + // } - if (!comfy) { - console.log(document.getElementById("comfyui-window")); - throw new Error("ComfyUI is not loaded"); - } + // await comfy.queuePrompt(1, 1); - await comfy.queuePrompt(1, 1); + // const p = new Promise((resolve, reject) => { - const image = await fetch(`${window.location.origin}/DummyImage.png`); - const blob = await image.blob(); - const createdAt = new Date(); + // }); - return { - id: `${Math.random() * 10000000}`, - images: [ - { - id: `${Math.random() * 10000000}`, - createdAt, - blob, - }, - { - id: `${Math.random() * 10000000}`, - createdAt, - blob, - }, - { - id: `${Math.random() * 10000000}`, - createdAt, - blob, - }, - { - id: `${Math.random() * 10000000}`, - createdAt, - blob, - }, - ], - }; - }, + // const image = await fetch(`${window.location.origin}/DummyImage.png`); + // const blob = await image.blob(); + // const createdAt = new Date(); - getStableDiffusionModels: async () => { - const resp = await fetch("/object_info", { cache: "no-cache" }); - const jsonResp = await resp.json(); + // return { + // id: `${Math.random() * 10000000}`, + // images: [ + // { + // id: `${Math.random() * 10000000}`, + // createdAt, + // blob, + // }, + // { + // id: `${Math.random() * 10000000}`, + // createdAt, + // blob, + // }, + // { + // id: `${Math.random() * 10000000}`, + // createdAt, + // blob, + // }, + // { + // id: `${Math.random() * 10000000}`, + // createdAt, + // blob, + // }, + // ], + // }; + // }, - console.log(jsonResp); - - return jsonResp?.CheckpointLoader?.input?.required?.ckpt_name?.[0]?.map( - (fileName: string) => ({ - id: fileName, - name: fileName, - }) - ); - }, - - getStableDiffusionSamplers: async () => { - const resp = await fetch("/object_info", { cache: "no-cache" }); - const jsonResp = await resp.json(); - - return jsonResp?.KSampler?.input?.required?.scheduler?.[0]?.map( - (name: string) => ({ - id: name, - name, - }) - ); - }, - - getStatus: () => { - fetch("/comfyui", { cache: "no-cache" }).then((resp) => { - set({ - statusStuff: { - indicator: resp.ok ? "success" : "error", - text: resp.ok ? "Running" : "Not Running", - }, + getStableDiffusionModels: async () => { + const resp = await fetch("/object_info/CheckpointLoader", { + cache: "no-cache", }); - }); + const jsonResp = await resp.json(); - return get().statusStuff; - }, -})); + console.log(jsonResp); + + return jsonResp?.CheckpointLoader?.input?.required?.ckpt_name?.[0]?.map( + (fileName: string) => ({ + id: fileName, + name: fileName, + }) + ); + }, + + getStableDiffusionSamplers: async () => { + const resp = await fetch("/object_info/KSamplerAdvanced", { + cache: "no-cache", + }); + const jsonResp = await resp.json(); + + return jsonResp?.KSamplerAdvanced?.input?.required?.sampler_name?.[0]?.map( + (name: string) => ({ + id: name, + name: name + .replace(/_/g, " ") + .replace("ddim", "DDIM") + .replace("lms", "LMS") + .replace("dpm", "DPM") + .replace("pp", "PP") + .replace("sde", "SDE") + .replace("2m", "2M") + .replace("2s", "2S") + .replace("gpu", "GPU") + .replace(/\w\S*/g, (w) => w.replace(/^\w/, (c) => c.toUpperCase())), + }) + ); + }, + + getStatus: () => { + fetch("/comfyui", { cache: "no-cache" }).then((resp) => { + set({ + statusStuff: { + indicator: resp.ok ? "success" : "error", + text: resp.ok ? "Running" : "Not Running", + }, + }); + }); + + return get().statusStuff; + }, + }; +}); diff --git a/packages/stablestudio-ui/src/Generation/Image/Create/index.tsx b/packages/stablestudio-ui/src/Generation/Image/Create/index.tsx index 382ba81..9b2b1c3 100644 --- a/packages/stablestudio-ui/src/Generation/Image/Create/index.tsx +++ b/packages/stablestudio-ui/src/Generation/Image/Create/index.tsx @@ -1,9 +1,7 @@ -import * as StableStudio from "@stability/stablestudio-plugin"; -import throttledQueue from "throttled-queue"; +import { Comfy } from "~/Comfy"; import { Generation } from "~/Generation"; import { GlobalState } from "~/GlobalState"; -import { Plugin } from "~/Plugin"; import { Button } from "./Button"; @@ -23,36 +21,24 @@ export namespace Create { ) => void; }; - namespace Throttle { - const requestsPerInterval = 1; - const interval = 500; - const spaceEvenly = true; - - const queue = throttledQueue(requestsPerInterval, interval, spaceEvenly); - export const wait = () => queue(() => Promise.resolve()); - } - export const execute = async ({ + // eslint-disable-next-line @typescript-eslint/no-unused-vars, no-unused-vars count = Generation.Image.Count.preset(), input, onStarted = doNothing, onException = doNothing, + // eslint-disable-next-line @typescript-eslint/no-unused-vars, no-unused-vars onSuccess = doNothing, onFinished = doNothing, }: Handlers & { count: number; input: Generation.Image.Input; - }): Promise => { - const { createStableDiffusionImages } = Plugin.get(); + }): Promise => { try { - if (!createStableDiffusionImages) throw new Error("Plugin not found"); - Latest.set(new Date()); onStarted(); - await Throttle.wait(); - const initImg = await Generation.Image.Input.resizeInit(input); const pluginInput = await Generation.Image.Input.toInput( !initImg @@ -72,43 +58,7 @@ export namespace Create { pluginInput.width = Math.ceil((pluginInput.width ?? 512) / 64) * 64; } - const responses: Generation.Images = []; - const response = await createStableDiffusionImages({ - input: pluginInput, - count, - }); - - if (response instanceof Error) throw response; - if (!response || !response?.images || response?.images?.length <= 0) - throw new Error(); - - const newInputs: Record = {}; - - for (const image of response.images) { - const inputID = ID.create(); - const newInput = { - ...Generation.Image.Input.initial(inputID), - ...input, - seed: image.input?.seed ?? input.seed, - id: inputID, - }; - - const cropped = await cropImage(image, newInput); - if (!cropped) continue; - - responses.push(cropped); - newInputs[inputID] = newInput; - } - - Generation.Image.Inputs.set({ - ...Generation.Image.Inputs.get(), - ...newInputs, - }); - - onSuccess(responses); - onFinished(responses); - - return responses; + await Comfy.get()?.queuePrompt(1, 1); } catch (caught: unknown) { const exception = Generation.Image.Exception.create(caught); @@ -142,30 +92,24 @@ export namespace Create { ...modifiers, }; - const output = Generation.Image.Output.requested(inputID, modifiers); - return execute({ count: modifiers.count ?? Generation.Image.Count.get(), input, - onStarted: () => { - Generation.Image.Output.set(output); + onStarted: (output) => { onStarted(output); }, onException: (exception) => { showErrorSnackbar(exception); onException(exception); - Generation.Image.Output.clear(output.id); }, onSuccess: (images) => { - images.forEach(Generation.Image.add); onSuccess(images); }, onFinished: (result) => { - Generation.Image.Output.received(output.id, result); onFinished(result); }, }); @@ -174,10 +118,7 @@ export namespace Create { ); }; - export const useIsEnabled = () => - Plugin.use( - ({ createStableDiffusionImages }) => !!createStableDiffusionImages - ); + export const useIsEnabled = () => Comfy.use(({ running }) => running); export type Latest = Date; export namespace Latest { @@ -202,52 +143,3 @@ export namespace Create { } } } - -// TODO: Move somewhere else -function cropImage( - image: StableStudio.StableDiffusionImage, - input: Generation.Image.Input -) { - return new Promise((resolve) => { - const id = image.id; - const blob = image.blob; - if (!blob || !id) return resolve(); - - // crop image to box size - const croppedCanvas = document.createElement("canvas"); - croppedCanvas.width = input.width; - croppedCanvas.height = input.height; - - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const croppedCtx = croppedCanvas.getContext("2d")!; - - const img = new window.Image(); - img.src = URL.createObjectURL(blob); - img.onload = () => { - croppedCtx.drawImage( - img, - 0, - 0, - input.width, - input.height, - 0, - 0, - input.width, - input.height - ); - - croppedCanvas.toBlob((blob) => { - if (blob) { - const objectURL = URL.createObjectURL(blob); - resolve({ - id, - inputID: input.id, - created: new Date(), - src: objectURL, - finishReason: 0, - }); - } - }); - }; - }); -} diff --git a/packages/stablestudio-ui/src/Generation/Image/Images/index.tsx b/packages/stablestudio-ui/src/Generation/Image/Images/index.tsx index 36867fb..51573f9 100644 --- a/packages/stablestudio-ui/src/Generation/Image/Images/index.tsx +++ b/packages/stablestudio-ui/src/Generation/Image/Images/index.tsx @@ -165,7 +165,7 @@ export function Images({ className }: Images.Props) { style={{ height: virtualizer.getTotalSize() }} >
Generation.Image.Output; received: ( @@ -37,8 +38,8 @@ export namespace State { nextID: ID.create(), - requested: (inputID, modifiers) => { - const id = get().nextID; + requested: (inputID, modifiers, nextID) => { + const id = nextID ?? get().nextID; const output = { id, inputID,