commit bce43c8b3321f7a5a49cda04b46329e29b741942 Author: Xe Iaso Date: Tue Jan 7 18:30:07 2025 -0500 first commit diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..777639c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,16 @@ +FROM ghcr.io/lecode-official/comfyui-docker:latest + +RUN pip install Flask \ + && git clone https://github.com/TemryL/ComfyS3 /opt/comfyui/custom_nodes/comfys3 \ + && pip install -r /opt/comfyui/custom_nodes/comfys3/requirements.txt \ + && rm -rf /opt/comfyui/custom_nodes/comfys3/.git \ + && rm /opt/comfyui/custom_nodes/comfys3/.env \ + && touch /opt/comfyui/custom_nodes/comfys3/.env \ + && git clone https://github.com/Ttl/ComfyUi_NNLatentUpscale /opt/comfyui/custom_nodes/ComfyUi_NNLatentUpscale \ + && rm -rf /opt/comfyui/custom_nodes/ComfyUi_NNLatentUpscale/.git + +COPY waifuwave.py . +COPY fetch_models.py . +COPY startup.sh . + +CMD ["/opt/comfyui/startup.sh"] \ No newline at end of file diff --git a/fetch_models.py b/fetch_models.py new file mode 100644 index 0000000..c4408cb --- /dev/null +++ b/fetch_models.py @@ -0,0 +1,94 @@ +from multiprocessing import Pool +from typing import Generator, Iterable, List +from urllib.parse import urlparse + +import os +import boto3 + + +models = [ + "checkpoints/counterfeitV30_v30.safetensors", + "embeddings/7dirtywords.pt", + "embeddings/easynegative.safetensors", + "embeddings/negative_hand-neg.pt", + "loras/pastelMixStylizedAnime_pastelMixLoraVersion.safetensors", + "loras/ligne_claire_anime.safetensors", + "vae/sdVAEForAnime_v10.pt", +] + + +def batcher(iterable: Iterable, batch_size: int) -> Generator[List, None, None]: + """Batch an iterator. The last item might be of smaller len than batch_size. + + Args: + iterable (Iterable): Any iterable that should be batched + batch_size (int): Len of the generated lists + + Yields: + Generator[List, None, None]: List of items in iterable + """ + batch = [] + counter = 0 + for i in iterable: + batch.append(i) + counter += 1 + if counter % batch_size == 0: + yield batch + batch = [] + if len(batch) > 0: + yield batch + + +def download_batch(batch) -> int: + s3 = boto3.client("s3") + n = 0 + for line in batch: + line, destdir = line + url = urlparse(line) + url_path = url.path.lstrip("/") + + folder, basename = os.path.split(url_path) + + dir = os.path.join(destdir, folder) + os.makedirs(dir, exist_ok=True) + filepath = os.path.join(dir, basename) + + if os.path.exists(filepath): + print(f"{line} already exists") + continue + + print(f"{line} -> {filepath}") + s3.download_file(url.netloc, url_path, filepath) + n += 1 + return n + + +def copy_from_tigris( + models: List[str] = models, + bucket_name: str = os.getenv("BUCKET_NAME", "comfyui"), + destdir: str = "/opt/comfyui", + n_cpus: int = os.cpu_count() + ): + """Copy files from Tigris to the destination folder. This will be done in parallel. + + Args: + models (List[str]): List of models to download. Defaults to the list of models in this file. + bucket_name (str): Tigris bucket to query. Defaults to envvar $BUCKET_NAME. + destdir (str): path to store the files. + n_cpus (int): number of simultaneous batches. Defaults to the number of cpus in the computer. + """ + + model_files = [ (f"s3://{bucket_name}/models/{x}", destdir) for x in models ] + + print(f"using {n_cpus} cpu cores for downloads") + n_cpus = min(len(model_files), n_cpus) + batch_size = len(model_files) // n_cpus + with Pool(processes=n_cpus) as pool: + for n in pool.imap_unordered( + download_batch, batcher(model_files, batch_size) + ): + pass + + +if __name__ == "__main__": + copy_from_tigris(n_cpus=999) \ No newline at end of file diff --git a/fly.toml b/fly.toml new file mode 100644 index 0000000..43a0395 --- /dev/null +++ b/fly.toml @@ -0,0 +1,26 @@ +# fly.toml app configuration file generated for waifuwave on 2025-01-07T17:07:19-05:00 +# +# See https://fly.io/docs/reference/configuration/ for information about how to use this file. +# + +app = 'waifuwave' +primary_region = 'ord' +vm.size = "l40s" + +[build] + +[http_service] +internal_port = 8080 +force_https = true +auto_stop_machines = 'stop' +auto_start_machines = true +min_machines_running = 0 +processes = ['app'] + + +[[http_service.checks]] +grace_period = "10s" +interval = "5s" +method = "GET" +timeout = "5s" +path = "/" diff --git a/startup.sh b/startup.sh new file mode 100755 index 0000000..7048693 --- /dev/null +++ b/startup.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +export S3_REGION="${AWS_REGION}" +export S3_ACCESS_KEY="${AWS_ACCESS_KEY_ID}" +export S3_SECRET_KEY="${AWS_SECRET_ACCESS_KEY}" +export S3_BUCKET_NAME="${BUCKET_NAME}" +export S3_ENDPOINT_URL="${AWS_ENDPOINT_URL_S3}" +export S3_INPUT_DIR="input" +export S3_OUTPUT_DIR="output" + +python fetch_models.py + +python waifuwave.py --host=0.0.0.0 --port=8080 \ No newline at end of file diff --git a/waifuwave.py b/waifuwave.py new file mode 100644 index 0000000..7e1afb5 --- /dev/null +++ b/waifuwave.py @@ -0,0 +1,295 @@ +import os +import random +import sys +from typing import Sequence, Mapping, Any, Union +import torch +import boto3 +from flask import Flask, jsonify, request + + +app = Flask(__name__) + + +def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any: + """Returns the value at the given index of a sequence or mapping. + + If the object is a sequence (like list or string), returns the value at the given index. + If the object is a mapping (like a dictionary), returns the value at the index-th key. + + Some return a dictionary, in these cases, we look for the "results" key + + Args: + obj (Union[Sequence, Mapping]): The object to retrieve the value from. + index (int): The index of the value to retrieve. + + Returns: + Any: The value at the given index. + + Raises: + IndexError: If the index is out of bounds for the object and the object is not a mapping. + """ + try: + return obj[index] + except KeyError: + return obj["result"][index] + + +def find_path(name: str, path: str = None) -> str: + """ + Recursively looks at parent folders starting from the given path until it finds the given name. + Returns the path as a Path object if found, or None otherwise. + """ + # If no path is given, use the current working directory + if path is None: + path = os.getcwd() + + # Check if the current directory contains the name + if name in os.listdir(path): + path_name = os.path.join(path, name) + print(f"{name} found: {path_name}") + return path_name + + # Get the parent directory + parent_directory = os.path.dirname(path) + + # If the parent directory is the same as the current directory, we've reached the root and stop the search + if parent_directory == path: + return None + + # Recursively call the function with the parent directory + return find_path(name, parent_directory) + + +def add_comfyui_directory_to_sys_path() -> None: + """ + Add 'ComfyUI' to the sys.path + """ + comfyui_path = find_path("ComfyUI") + if comfyui_path is not None and os.path.isdir(comfyui_path): + sys.path.append(comfyui_path) + print(f"'{comfyui_path}' added to sys.path") + + +def add_extra_model_paths() -> None: + """ + Parse the optional extra_model_paths.yaml file and add the parsed paths to the sys.path. + """ + try: + from server import load_extra_path_config + except ImportError: + print( + "Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead." + ) + try: + from utils.extra_config import load_extra_path_config + except ImportError: + return + + + extra_model_paths = find_path("extra_model_paths.yaml") + + if extra_model_paths is not None: + load_extra_path_config(extra_model_paths) + else: + print("Could not find the extra_model_paths config file.") + + +add_comfyui_directory_to_sys_path() +add_extra_model_paths() + + +def import_custom_nodes() -> None: + """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS + + This function sets up a new asyncio event loop, initializes the PromptServer, + creates a PromptQueue, and initializes the custom nodes. + """ + import asyncio + import execution + from nodes import init_extra_nodes + import server + + # Creating a new event loop and setting it as the default loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Creating an instance of PromptServer with the loop + server_instance = server.PromptServer(loop) + execution.PromptQueue(server_instance) + + # Initializing custom nodes + init_extra_nodes() + + +from nodes import ( + VAEDecode, + KSampler, + NODE_CLASS_MAPPINGS, + VAELoader, + VAEEncode, + CheckpointLoaderSimple, + CLIPTextEncode, + EmptyLatentImage, + LoraLoader, +) + + +def generate_image(prompt: str, negative_prompt: str): + import_custom_nodes() + with torch.inference_mode(): + checkpointloadersimple = CheckpointLoaderSimple() + checkpointloadersimple_4 = checkpointloadersimple.load_checkpoint( + ckpt_name="counterfeitV30_v30.safetensors" + ) + + emptylatentimage = EmptyLatentImage() + emptylatentimage_5 = emptylatentimage.generate( + width=768, height=384, batch_size=1 + ) + + loraloader = LoraLoader() + loraloader_51 = loraloader.load_lora( + lora_name="pastelMixStylizedAnime_pastelMixLoraVersion.safetensors", + strength_model=1, + strength_clip=1, + model=get_value_at_index(checkpointloadersimple_4, 0), + clip=get_value_at_index(checkpointloadersimple_4, 1), + ) + + loraloader_61 = loraloader.load_lora( + lora_name="ligne_claire_anime.safetensors", + strength_model=1, + strength_clip=1, + model=get_value_at_index(loraloader_51, 0), + clip=get_value_at_index(loraloader_51, 1), + ) + + cliptextencode = CLIPTextEncode() + cliptextencode_6 = cliptextencode.encode( + text=f"(masterpiece, best quality), {prompt}", + clip=get_value_at_index(loraloader_61, 1), + ) + + vaeloader = VAELoader() + vaeloader_12 = vaeloader.load_vae(vae_name="sdVAEForAnime_v10.pt") + + cliptextencode_38 = cliptextencode.encode( + text=f"embedding:easynegative, embedding:negative_hand-neg, embedding:7dirtywords, {negative_prompt}", + clip=get_value_at_index(loraloader_61, 1), + ) + + ksampler = KSampler() + ksampler_3 = ksampler.sample( + seed=random.randint(1, 2**64), + steps=26, + cfg=6, + sampler_name="dpmpp_2m", + scheduler="karras", + denoise=1, + model=get_value_at_index(loraloader_61, 0), + positive=get_value_at_index(cliptextencode_6, 0), + negative=get_value_at_index(cliptextencode_38, 0), + latent_image=get_value_at_index(emptylatentimage_5, 0), + ) + + vaedecode = VAEDecode() + vaedecode_47 = vaedecode.decode( + samples=get_value_at_index(ksampler_3, 0), + vae=get_value_at_index(vaeloader_12, 0), + ) + + imagesharpen = NODE_CLASS_MAPPINGS["ImageSharpen"]() + imagesharpen_85 = imagesharpen.sharpen( + sharpen_radius=1, + sigma=1, + alpha=1, + image=get_value_at_index(vaedecode_47, 0), + ) + + vaeencode = VAEEncode() + vaeencode_86 = vaeencode.encode( + pixels=get_value_at_index(imagesharpen_85, 0), + vae=get_value_at_index(vaeloader_12, 0), + ) + + nnlatentupscale = NODE_CLASS_MAPPINGS["NNLatentUpscale"]() + saveimages3 = NODE_CLASS_MAPPINGS["SaveImageS3"]() + + nnlatentupscale_31 = nnlatentupscale.upscale( + version="SD 1.x", + upscale=2.0, + latent=get_value_at_index(vaeencode_86, 0), + ) + + ksampler_53 = ksampler.sample( + seed=random.randint(1, 2**64), + steps=30, + cfg=6, + sampler_name="dpmpp_2m", + scheduler="karras", + denoise=1, + model=get_value_at_index(loraloader_61, 0), + positive=get_value_at_index(cliptextencode_6, 0), + negative=get_value_at_index(cliptextencode_38, 0), + latent_image=get_value_at_index(nnlatentupscale_31, 0), + ) + + vaedecode_42 = vaedecode.decode( + samples=get_value_at_index(ksampler_53, 0), + vae=get_value_at_index(vaeloader_12, 0), + ) + + saveimages3_89 = saveimages3.save_images( + filename_prefix="waifu", images=get_value_at_index(vaedecode_42, 0) + ) + + return get_value_at_index(saveimages3_89, 0) + + +def generate_presigned_url(bucket_name: str, object_name: str, expiration: int = 3600): + s3_client = boto3.client( + "s3", + aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + endpoint_url=os.environ["AWS_ENDPOINT_URL_S3"], + region_name=os.environ.get("AWS_REGION", None), + ) + + try: + response = s3_client.generate_presigned_url( + "get_object", + Params={"Bucket": bucket_name, "Key": object_name}, + ExpiresIn=expiration, + ) + except Exception as e: + print(f"Error generating presigned URL: {e}") + return None + + return response + + +@app.route("/", methods=["GET"]) +def read_root(): + return jsonify({"Hello": "World"}) + + +@app.route("/generate", methods=["POST"]) +def generate(): + content_type = request.headers.get('Content-Type') + if (content_type == 'application/json'): + json = request.json + else: + return 'Content-Type not supported!' + + image_response = generate_image(json["prompt"], json["negative_prompt"]) + + return jsonify({ + "fname": image_response, + "url": generate_presigned_url( + os.getenv("BUCKET_NAME", "comfyui"), image_response[0] + ), + }) + + +if __name__ == "__main__": + app.run(host="0.0.0.0", port=8080)