generation requests return images

This commit is contained in:
kaj
2023-07-20 11:39:21 -08:00
parent 4d435a4d6d
commit e7089b8a7b
9 changed files with 289 additions and 205 deletions

View File

@@ -6,6 +6,9 @@ on:
branches:
- tauri
concurrency:
group: comfyui_windows
jobs:
repackage_comfyui:
permissions:

View File

@@ -2,7 +2,6 @@
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
use std::collections::HashMap;
use std::fmt::format;
use std::fs::File;
use std::sync::OnceLock;
use tauri::api::process::CommandEvent;

View File

@@ -106,6 +106,11 @@ impl Builder {
"`ws${window.location.protocol === \"https:\" ? \"s\" : \"\"}://${location.host}/ws${existingSession}`",
"`ws://localhost:5000/ws${existingSession}`"
);
// add some stuff to app.js
if path_name.ends_with("app.js") && !file_contents.ends_with("app.api = api;") {
file_contents = file_contents + "\napp.api = api;";
}
let response = HttpResponse::from_data(file_contents.as_bytes()).with_status_code(200).with_header(
Header::from_bytes("Content-Type", mimetype.unwrap().as_str()).unwrap(),

View File

@@ -222,6 +222,7 @@ export namespace App {
setRunning(true);
setIsSetup(SetupState.ComfyRunning);
Comfy.registerListeners();
}, [isSetup, print, setRunning, setUnlisteners]);
useEffect(() => {

View File

@@ -1,7 +1,9 @@
import * as StableStudio from "@stability/stablestudio-plugin";
import { useLocation } from "react-router-dom";
import { create } from "zustand";
import { Generation } from "~/Generation";
export type Comfy = {
export type ComfyApp = {
setup: () => void;
registerNodes: () => void;
loadGraphData: (graph: Graph) => void;
@@ -13,6 +15,21 @@ export type Comfy = {
refreshComboInNodes: () => Promise<void>;
queuePrompt: (number: number, batchCount: number) => Promise<void>;
clean: () => void;
api: ComfyAPI;
};
export type ComfyAPI = {
addEventListener: (event: string, callback: (detail: any) => void) => void;
};
export type Comfy = { app: ComfyApp; api: ComfyAPI };
export type ComfyOutput = {
images: {
filename: string;
subfolder: string;
type: string;
}[];
};
export type Graph = {
@@ -91,11 +108,11 @@ type State = {
};
export namespace Comfy {
export const get = (): Comfy | null =>
export const get = (): ComfyApp | null =>
((
(document.getElementById("comfyui-window") as HTMLIFrameElement)
?.contentWindow as Window & { app: Comfy }
)?.app as Comfy) ?? null;
?.contentWindow as Window & { app: ComfyApp }
)?.app as ComfyApp) ?? null;
export const use = create<State>((set) => ({
output: [],
@@ -115,4 +132,152 @@ export namespace Comfy {
unlisteners: [],
setUnlisteners: (unlisteners) => set({ unlisteners }),
}));
export const registerListeners = async () => {
let api = get()?.api;
while (!api) {
await new Promise((resolve) => setTimeout(resolve, 1000));
api = get()?.api;
}
api.addEventListener("executed", async ({ detail }) => {
const { output, prompt_id } = detail;
console.log("executed_in_comfy_domain", detail);
const newInputs: Record<ID, Generation.Image.Input> = {};
const responses: Generation.Images = [];
const input = Generation.Image.Input.get(prompt_id);
const images = await Promise.all(
(output as ComfyOutput).images.map(async (image) => {
console.log("image", image);
const resp = await fetch(
`http://localhost:3000/view?filename=${image.filename}&subfolder=${
image.subfolder || ""
}&type=${image.type}`,
{
cache: "no-cache",
}
);
const blob = await resp.blob();
const url = URL.createObjectURL(blob);
console.log("url", url);
const output = Generation.Image.Output.get(prompt_id);
return {
id: ID.create(),
blob,
inputID: output?.inputID ?? "",
createdAt: new Date(),
};
})
);
for (const image of images) {
const inputID = ID.create();
const newInput = {
...Generation.Image.Input.initial(inputID),
...input,
seed: 0,
id: inputID,
};
const cropped = await cropImage(image, newInput);
if (!cropped) continue;
responses.push(cropped);
newInputs[inputID] = newInput;
}
Generation.Image.Inputs.set({
...Generation.Image.Inputs.get(),
...newInputs,
});
responses.forEach(Generation.Image.add);
Generation.Image.Output.received(prompt_id, responses);
});
api.addEventListener("execution_start", ({ detail }) => {
const { prompt_id } = detail;
console.log("execution_start", detail);
if (prompt_id) {
let input = Generation.Image.Input.get(prompt_id);
if (!input) {
input = Generation.Image.Input.initial(prompt_id);
Generation.Image.Inputs.set((inputs) => ({
...inputs,
[prompt_id]: input,
}));
}
const output = Generation.Image.Output.requested(
prompt_id,
{},
prompt_id
);
Generation.Image.Output.set(output);
}
});
api.addEventListener("execution_error", ({ detail }) => {
console.log("execution_error", detail);
Generation.Image.Output.clear(detail.prompt_id);
});
console.log("registered ComfyUI listeners");
};
}
function cropImage(
image: StableStudio.StableDiffusionImage,
input: Generation.Image.Input
) {
return new Promise<Generation.Image | void>((resolve) => {
const id = image.id;
const blob = image.blob;
if (!blob || !id) return resolve();
// crop image to box size
const croppedCanvas = document.createElement("canvas");
croppedCanvas.width = input.width;
croppedCanvas.height = input.height;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const croppedCtx = croppedCanvas.getContext("2d")!;
const img = new window.Image();
img.src = URL.createObjectURL(blob);
img.onload = () => {
croppedCtx.drawImage(
img,
0,
0,
input.width,
input.height,
0,
0,
input.width,
input.height
);
croppedCanvas.toBlob((blob) => {
if (blob) {
const objectURL = URL.createObjectURL(blob);
resolve({
id,
inputID: input.id,
created: new Date(),
src: objectURL,
finishReason: 0,
});
}
});
};
});
}

View File

@@ -1,94 +1,112 @@
import * as StableStudio from "@stability/stablestudio-plugin";
import { Comfy } from ".";
export const createPlugin = StableStudio.createPlugin<any>(({ set, get }) => {
return {
manifest: {
name: "ComfyUI Backend",
},
export const createPlugin = StableStudio.createPlugin<any>(({ set, get }) => ({
manifest: {
name: "ComfyUI Backend",
},
statusStuff: {
indicator: "loading",
text: "Starting",
},
statusStuff: {
indicator: "loading",
text: "Starting",
},
// createStableDiffusionImages: async () => {
// const comfy = Comfy.get();
createStableDiffusionImages: async () => {
const comfy = Comfy.get();
// if (!comfy) {
// console.log(document.getElementById("comfyui-window"));
// throw new Error("ComfyUI is not loaded");
// }
if (!comfy) {
console.log(document.getElementById("comfyui-window"));
throw new Error("ComfyUI is not loaded");
}
// await comfy.queuePrompt(1, 1);
await comfy.queuePrompt(1, 1);
// const p = new Promise((resolve, reject) => {
const image = await fetch(`${window.location.origin}/DummyImage.png`);
const blob = await image.blob();
const createdAt = new Date();
// });
return {
id: `${Math.random() * 10000000}`,
images: [
{
id: `${Math.random() * 10000000}`,
createdAt,
blob,
},
{
id: `${Math.random() * 10000000}`,
createdAt,
blob,
},
{
id: `${Math.random() * 10000000}`,
createdAt,
blob,
},
{
id: `${Math.random() * 10000000}`,
createdAt,
blob,
},
],
};
},
// const image = await fetch(`${window.location.origin}/DummyImage.png`);
// const blob = await image.blob();
// const createdAt = new Date();
getStableDiffusionModels: async () => {
const resp = await fetch("/object_info", { cache: "no-cache" });
const jsonResp = await resp.json();
// return {
// id: `${Math.random() * 10000000}`,
// images: [
// {
// id: `${Math.random() * 10000000}`,
// createdAt,
// blob,
// },
// {
// id: `${Math.random() * 10000000}`,
// createdAt,
// blob,
// },
// {
// id: `${Math.random() * 10000000}`,
// createdAt,
// blob,
// },
// {
// id: `${Math.random() * 10000000}`,
// createdAt,
// blob,
// },
// ],
// };
// },
console.log(jsonResp);
return jsonResp?.CheckpointLoader?.input?.required?.ckpt_name?.[0]?.map(
(fileName: string) => ({
id: fileName,
name: fileName,
})
);
},
getStableDiffusionSamplers: async () => {
const resp = await fetch("/object_info", { cache: "no-cache" });
const jsonResp = await resp.json();
return jsonResp?.KSampler?.input?.required?.scheduler?.[0]?.map(
(name: string) => ({
id: name,
name,
})
);
},
getStatus: () => {
fetch("/comfyui", { cache: "no-cache" }).then((resp) => {
set({
statusStuff: {
indicator: resp.ok ? "success" : "error",
text: resp.ok ? "Running" : "Not Running",
},
getStableDiffusionModels: async () => {
const resp = await fetch("/object_info/CheckpointLoader", {
cache: "no-cache",
});
});
const jsonResp = await resp.json();
return get().statusStuff;
},
}));
console.log(jsonResp);
return jsonResp?.CheckpointLoader?.input?.required?.ckpt_name?.[0]?.map(
(fileName: string) => ({
id: fileName,
name: fileName,
})
);
},
getStableDiffusionSamplers: async () => {
const resp = await fetch("/object_info/KSamplerAdvanced", {
cache: "no-cache",
});
const jsonResp = await resp.json();
return jsonResp?.KSamplerAdvanced?.input?.required?.sampler_name?.[0]?.map(
(name: string) => ({
id: name,
name: name
.replace(/_/g, " ")
.replace("ddim", "DDIM")
.replace("lms", "LMS")
.replace("dpm", "DPM")
.replace("pp", "PP")
.replace("sde", "SDE")
.replace("2m", "2M")
.replace("2s", "2S")
.replace("gpu", "GPU")
.replace(/\w\S*/g, (w) => w.replace(/^\w/, (c) => c.toUpperCase())),
})
);
},
getStatus: () => {
fetch("/comfyui", { cache: "no-cache" }).then((resp) => {
set({
statusStuff: {
indicator: resp.ok ? "success" : "error",
text: resp.ok ? "Running" : "Not Running",
},
});
});
return get().statusStuff;
},
};
});

View File

@@ -1,9 +1,7 @@
import * as StableStudio from "@stability/stablestudio-plugin";
import throttledQueue from "throttled-queue";
import { Comfy } from "~/Comfy";
import { Generation } from "~/Generation";
import { GlobalState } from "~/GlobalState";
import { Plugin } from "~/Plugin";
import { Button } from "./Button";
@@ -23,36 +21,24 @@ export namespace Create {
) => void;
};
namespace Throttle {
const requestsPerInterval = 1;
const interval = 500;
const spaceEvenly = true;
const queue = throttledQueue(requestsPerInterval, interval, spaceEvenly);
export const wait = () => queue(() => Promise.resolve());
}
export const execute = async ({
// eslint-disable-next-line @typescript-eslint/no-unused-vars, no-unused-vars
count = Generation.Image.Count.preset(),
input,
onStarted = doNothing,
onException = doNothing,
// eslint-disable-next-line @typescript-eslint/no-unused-vars, no-unused-vars
onSuccess = doNothing,
onFinished = doNothing,
}: Handlers & {
count: number;
input: Generation.Image.Input;
}): Promise<Generation.Image.Exception | Generation.Images> => {
const { createStableDiffusionImages } = Plugin.get();
}): Promise<undefined | Generation.Image.Exception> => {
try {
if (!createStableDiffusionImages) throw new Error("Plugin not found");
Latest.set(new Date());
onStarted();
await Throttle.wait();
const initImg = await Generation.Image.Input.resizeInit(input);
const pluginInput = await Generation.Image.Input.toInput(
!initImg
@@ -72,43 +58,7 @@ export namespace Create {
pluginInput.width = Math.ceil((pluginInput.width ?? 512) / 64) * 64;
}
const responses: Generation.Images = [];
const response = await createStableDiffusionImages({
input: pluginInput,
count,
});
if (response instanceof Error) throw response;
if (!response || !response?.images || response?.images?.length <= 0)
throw new Error();
const newInputs: Record<ID, Generation.Image.Input> = {};
for (const image of response.images) {
const inputID = ID.create();
const newInput = {
...Generation.Image.Input.initial(inputID),
...input,
seed: image.input?.seed ?? input.seed,
id: inputID,
};
const cropped = await cropImage(image, newInput);
if (!cropped) continue;
responses.push(cropped);
newInputs[inputID] = newInput;
}
Generation.Image.Inputs.set({
...Generation.Image.Inputs.get(),
...newInputs,
});
onSuccess(responses);
onFinished(responses);
return responses;
await Comfy.get()?.queuePrompt(1, 1);
} catch (caught: unknown) {
const exception = Generation.Image.Exception.create(caught);
@@ -142,30 +92,24 @@ export namespace Create {
...modifiers,
};
const output = Generation.Image.Output.requested(inputID, modifiers);
return execute({
count: modifiers.count ?? Generation.Image.Count.get(),
input,
onStarted: () => {
Generation.Image.Output.set(output);
onStarted: (output) => {
onStarted(output);
},
onException: (exception) => {
showErrorSnackbar(exception);
onException(exception);
Generation.Image.Output.clear(output.id);
},
onSuccess: (images) => {
images.forEach(Generation.Image.add);
onSuccess(images);
},
onFinished: (result) => {
Generation.Image.Output.received(output.id, result);
onFinished(result);
},
});
@@ -174,10 +118,7 @@ export namespace Create {
);
};
export const useIsEnabled = () =>
Plugin.use(
({ createStableDiffusionImages }) => !!createStableDiffusionImages
);
export const useIsEnabled = () => Comfy.use(({ running }) => running);
export type Latest = Date;
export namespace Latest {
@@ -202,52 +143,3 @@ export namespace Create {
}
}
}
// TODO: Move somewhere else
function cropImage(
image: StableStudio.StableDiffusionImage,
input: Generation.Image.Input
) {
return new Promise<Generation.Image | void>((resolve) => {
const id = image.id;
const blob = image.blob;
if (!blob || !id) return resolve();
// crop image to box size
const croppedCanvas = document.createElement("canvas");
croppedCanvas.width = input.width;
croppedCanvas.height = input.height;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const croppedCtx = croppedCanvas.getContext("2d")!;
const img = new window.Image();
img.src = URL.createObjectURL(blob);
img.onload = () => {
croppedCtx.drawImage(
img,
0,
0,
input.width,
input.height,
0,
0,
input.width,
input.height
);
croppedCanvas.toBlob((blob) => {
if (blob) {
const objectURL = URL.createObjectURL(blob);
resolve({
id,
inputID: input.id,
created: new Date(),
src: objectURL,
finishReason: 0,
});
}
});
};
});
}

View File

@@ -165,7 +165,7 @@ export function Images({ className }: Images.Props) {
style={{ height: virtualizer.getTotalSize() }}
>
<div
className="absolute top-0 left-0 w-full"
className="absolute left-0 top-0 w-full"
style={{
transform: `translateY(${
(virtualItems[0]?.start ?? 0) - virtualizer.options.scrollMargin

View File

@@ -9,7 +9,8 @@ export type State = {
requested: (
inputID: ID,
modifiers?: Generation.Image.Input.Modifiers
modifiers?: Generation.Image.Input.Modifiers,
nextID?: ID
) => Generation.Image.Output;
received: (
@@ -37,8 +38,8 @@ export namespace State {
nextID: ID.create(),
requested: (inputID, modifiers) => {
const id = get().nextID;
requested: (inputID, modifiers, nextID) => {
const id = nextID ?? get().nextID;
const output = {
id,
inputID,