Added learning model for image generation

This commit is contained in:
Nicolò 2025-08-02 10:50:03 +02:00
parent dbf492898c
commit 3930e3f4a2
19 changed files with 173 additions and 408 deletions

54
app.py
View file

@ -1,3 +1,4 @@
import os
import streamlit as st
import pandas as pd
import joblib
@ -6,10 +7,10 @@ import torch
from torchvision import models
from llama_cpp import Llama
from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
from image_generation import generate_image
st.set_page_config(page_title="Plant Growth Predictor", layout="centered")
st.title("🌱 Plant Growth Predictor")
st.set_page_config(page_title="GreenThumber", layout="centered")
st.title("🌱 GreenThumber")
@st.cache_resource
@ -35,6 +36,7 @@ def load_mistral_model():
return llm
llm = load_mistral_model()
# Generate a description using the Mistral model
def generate_growth_description(plant_type, soil_type, sunlight_hours, water_frequency,
@ -52,24 +54,22 @@ def generate_growth_description(plant_type, soil_type, sunlight_hours, water_fre
f"- Humidity: {humidity}%\n"
f"### Response:\n"
)
output = llm(prompt, max_tokens=250, stop=["###"])
output = llm(prompt, max_tokens=100, stop=["###"])
return output["choices"][0]["text"].strip()
def generate_condition_image(description: str, input_image: Image.Image) -> Image.Image:
input_image = input_image.convert("RGB").resize((512, 512))
st.spinner("Generating predicted plant condition image...")
st.header("Plant Info")
plant_input_mode = st.radio("How would you like to provide plant info?", ("Type name", "Upload image"))
plant_type = None
uploaded_image = None
if plant_input_mode == "Type name":
plant_type = st.selectbox("Select Plant Type", ["Basil", "Tomato", "Lettuce", "Rosemary", "Other"])
plant_type = st.selectbox("Select Plant Type", ["Basil", "Tomato", "Lettuce"])
plant_age = st.number_input("Enter Plant Age (in days)", min_value=1, max_value=365, value=25)
elif plant_input_mode == "Upload image":
plant_type = st.selectbox("Select Plant Type", ["Basil", "Tomato", "Lettuce", "Rosemary", "Other"])
plant_type = st.selectbox("Select Plant Type", ["Basil", "Tomato", "Lettuce"])
plant_age = st.number_input("Enter Plant Age (in days)", min_value=1, max_value=365, value=30)
image_file = st.file_uploader("Upload an image of your plant", type=["jpg", "jpeg", "png"])
if image_file:
uploaded_image = Image.open(image_file)
@ -98,7 +98,7 @@ with col2:
additional_info = st.text_area("Feel free to include any additional detail")
# Prediction + Description + Image Generation
if st.button("Predict Growth Milestone and Generate Description & Image"):
if st.button("Start Prediction"):
if plant_type and plant_type.strip() != "":
if plant_input_mode == "Upload image" and uploaded_image is None:
st.warning("Please upload a plant image to proceed.")
@ -110,40 +110,14 @@ if st.button("Predict Growth Milestone and Generate Description & Image"):
)
st.subheader(f"📝 Predicted Plant Condition in {days} Days:")
st.write(description)
# Use uploaded image if available, else placeholder or skip image generation
if plant_input_mode == "Upload image" and uploaded_image:
manipulated_img = generate_condition_image(description, uploaded_image)
st.image(manipulated_img, caption="Predicted Plant Condition Image")
else:
st.info("Image prediction requires uploading a plant image.")
manipulated_img = generate_image(plant_type, description, plant_age)
st.image(manipulated_img, caption="Predicted Plant Condition Image")
else:
st.warning("Please select or enter a plant type.")
@st.cache_resource
def load_sd():
model_id = "stabilityai/stable-diffusion-2-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_attention_slicing()
pipe.to("cpu")
return pipe
pipe = load_sd()
st.write(description)
with st.spinner("Generating plant image..."):
results = pipe(
description,
num_inference_steps=50,
guidance_scale=3.5,
height=512,
width=512
)
image = results.images[0]
st.image(image, caption="Predicted Plant Condition", use_column_width=True)
st.markdown("---")
st.caption("Made with ❤️ by Sandwich Craftz.")