137 lines
4.4 KiB
Python
137 lines
4.4 KiB
Python
import torch
|
|
from diffusers import (
|
|
StableDiffusionXLControlNetImg2ImgPipeline,
|
|
UniPCMultistepScheduler,
|
|
ControlNetModel,
|
|
StableDiffusionImg2ImgPipeline
|
|
)
|
|
from PIL import Image
|
|
import json
|
|
import numpy as np
|
|
|
|
|
|
# ============================
|
|
# 🚀 Inizializzazione pipeline
|
|
# ============================
|
|
def load_pipelines():
    """Load the two diffusion pipelines used by this script onto the GPU.

    Returns:
        tuple: ``(pipe, pipeNoImg)`` where ``pipe`` is the Stable Diffusion XL
        ControlNet img2img pipeline (canny ControlNet, UniPC scheduler) and
        ``pipeNoImg`` is a plain Stable Diffusion 1.5 img2img pipeline.
        Both are loaded in float16 and moved to ``"cuda"``.
    """
    # The ControlNet has to belong to the same model family as the base
    # pipeline. "lllyasviel/sd-controlnet-canny" is a Stable Diffusion 1.5
    # checkpoint and cannot be combined with the SDXL base model below, so
    # the SDXL canny checkpoint is loaded instead.
    controlnet = ControlNetModel.from_pretrained(
        "diffusers/controlnet-canny-sdxl-1.0",
        torch_dtype=torch.float16,
    )

    # Stable Diffusion XL + ControlNet (img2img).
    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnet,
        torch_dtype=torch.float16,
    ).to("cuda")

    # Plain img2img pipeline (no ControlNet).
    pipeNoImg = StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
    ).to("cuda")

    # Swap the default scheduler for UniPC, reusing the existing config.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    # xformers is an optional package: treat it as a best-effort memory
    # optimisation instead of aborting the whole script when it is missing.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError) as exc:
        print(f"⚠️ xformers non disponibile, uso l'attenzione di default: {exc}")

    return pipe, pipeNoImg
|
|
|
|
|
|
# ============================
|
|
# 🔤 Prompt da JSON
|
|
# ============================
|
|
def json_to_prompt(state_json):
    """Build the text prompt for the diffusion pipeline from the plant state.

    The prompt always starts with a fixed description of the scene; optional
    keys of ``state_json`` append extra comma-separated clauses to it.

    Args:
        state_json: Mapping with the predicted plant state. Recognised keys:
            ``"leaf_color"`` (appended verbatim as ``", leaves are <value>"``)
            and ``"health"`` (``"good"``, ``"average"`` or ``"bad"`` are
            mapped to a descriptive phrase; any other value is appended
            as-is). All other keys are ignored.

    Returns:
        str: The complete prompt.
    """
    # Fixed base description. The text is reproduced exactly as authored
    # (including the missing spaces after some periods) so that generated
    # prompts stay unchanged.
    prompt = (
        "A highly realistic photograph of the same grape plant in its pot, "
        "observed a month later."
        "The plant is slightly taller and thinner, with fewer leaves overall."
        "Most leaves are pale yellow-green, with some edges turning brown "
        "and slightly dried,"
        "showing early signs of stress and aging."
        "The background and lighting must remain unchanged, with sharp details, "
        "natural shadows, and realistic leaf textures."
    )

    # Example parameters -> extend with more keys as needed.
    if "leaf_color" in state_json:
        prompt += f", leaves are {state_json['leaf_color']}"

    if "health" in state_json:
        health_phrases = {
            "good": "looking healthy",
            "average": "in average condition",
            "bad": "looking unhealthy or wilted",
        }
        health = state_json["health"]
        # Unknown health values fall through unchanged.
        prompt += f", {health_phrases.get(health, health)}"

    return prompt
|
|
|
|
|
|
# ============================
|
|
# 🌿 Generazione immagine
|
|
# ============================
|
|
def generate_plant_image(pipe, prompt, use_img=False, base_image=None, control_image=None, controlnet_scale=1.0, strength=0.2):
    """Run ``pipe`` once and return the first generated image.

    Args:
        pipe: Callable pipeline. It is invoked with keyword arguments only
            and must return an object exposing an ``images`` sequence.
        prompt: Text prompt forwarded to the pipeline.
        use_img: When true, run img2img with ControlNet conditioning on
            ``base_image`` / ``control_image``. When false, run the pipeline
            on a random-noise image at full strength ("simulated" txt2img).
        base_image: Starting image for the ControlNet branch; required when
            ``use_img`` is true, ignored otherwise.
        control_image: ControlNet conditioning image; required when
            ``use_img`` is true, ignored otherwise.
        controlnet_scale: Passed as ``controlnet_conditioning_scale`` in the
            ControlNet branch; ignored otherwise.
        strength: img2img strength for the ControlNet branch; the noise
            branch always uses ``1.0``.

    Returns:
        The first element of the pipeline output's ``images``.

    Raises:
        ValueError: If ``use_img`` is true and ``base_image`` or
            ``control_image`` is missing.
    """
    if use_img and base_image is None:
        raise ValueError("Hai impostato use_img=True ma non hai passato base_image!")
    # Fail early with a clear message instead of letting the pipeline choke
    # on a missing conditioning image.
    if use_img and control_image is None:
        raise ValueError("Hai impostato use_img=True ma non hai passato control_image!")

    if use_img:
        # --- Img2Img with ControlNet ---
        result = pipe(
            prompt=prompt,
            image=base_image,
            control_image=control_image,
            guidance_scale=7.5,
            controlnet_conditioning_scale=controlnet_scale,
            strength=strength,
            num_inference_steps=40,
            negative_prompt="cartoon, illustration, painting, extra plants, changed background",
        ).images[0]
    else:
        # --- Simulated txt2img ---
        # Start from a 384x384 uniform-noise RGB image; with strength=1.0 the
        # result is driven by the prompt rather than by this input.
        noise = (torch.rand(3, 384, 384) * 255).byte().numpy().transpose(1, 2, 0)
        noise_img = Image.fromarray(noise, mode="RGB")
        result = pipe(
            prompt=prompt,
            image=noise_img,
            strength=1.0,
            guidance_scale=7.0,
            num_inference_steps=25,
        ).images[0]

    return result
|
|
|
|
|
|
# ============================
|
|
# 🏁 MAIN
|
|
# ============================
|
|
def main():
    """Script entry point: load the pipelines, build the prompt, render, save.

    Reads the plant state from ``prediction.json``, loads ``baseImg.jpeg``
    and ``controlImg.png`` (both converted to RGB and resized to 512x512),
    runs the ControlNet img2img pipeline, then shows the result and writes
    it to ``plant_output.png``.
    """
    # Both pipelines are loaded; only the ControlNet one is used below.
    controlnet_pipe, _plain_pipe = load_pipelines()

    # Plant state (JSON) -> text prompt.
    with open("prediction.json", "r", encoding="utf-8") as state_file:
        plant_state = json.load(state_file)
    prompt_text = json_to_prompt(plant_state)
    print(f"Prompt generato:\n{prompt_text}")

    # Input images, both normalised to the same mode and size.
    target_size = (512, 512)
    starting_img = Image.open("baseImg.jpeg").convert("RGB").resize(target_size)
    conditioning_img = Image.open("controlImg.png").convert("RGB").resize(target_size)

    # Generate the image.
    generated = generate_plant_image(
        controlnet_pipe,
        prompt_text,
        use_img=True,
        base_image=starting_img,
        control_image=conditioning_img,
        controlnet_scale=0.5,
        strength=0.2,
    )

    # Show and save.
    generated.show()
    generated.save("plant_output.png")
    print("✅ Immagine salvata come plant_output.png")


if __name__ == "__main__":
    main()
|