Added learning model for image generation
This commit is contained in:
parent
dbf492898c
commit
3930e3f4a2
19 changed files with 173 additions and 408 deletions
54
app.py
54
app.py
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
import joblib
|
||||
|
@ -6,10 +7,10 @@ import torch
|
|||
from torchvision import models
|
||||
from llama_cpp import Llama
|
||||
from diffusers import DiffusionPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
|
||||
from image_generation import generate_image
|
||||
|
||||
|
||||
st.set_page_config(page_title="Plant Growth Predictor", layout="centered")
|
||||
st.title("🌱 Plant Growth Predictor")
|
||||
st.set_page_config(page_title="GreenThumber", layout="centered")
|
||||
st.title("🌱 GreenThumber")
|
||||
|
||||
|
||||
@st.cache_resource
|
||||
|
@ -35,6 +36,7 @@ def load_mistral_model():
|
|||
return llm
|
||||
|
||||
llm = load_mistral_model()
|
||||
|
||||
|
||||
# Generate a description using the Mistral model
|
||||
def generate_growth_description(plant_type, soil_type, sunlight_hours, water_frequency,
|
||||
|
@ -52,24 +54,22 @@ def generate_growth_description(plant_type, soil_type, sunlight_hours, water_fre
|
|||
f"- Humidity: {humidity}%\n"
|
||||
f"### Response:\n"
|
||||
)
|
||||
output = llm(prompt, max_tokens=250, stop=["###"])
|
||||
output = llm(prompt, max_tokens=100, stop=["###"])
|
||||
return output["choices"][0]["text"].strip()
|
||||
|
||||
|
||||
def generate_condition_image(description: str, input_image: Image.Image) -> Image.Image:
|
||||
input_image = input_image.convert("RGB").resize((512, 512))
|
||||
st.spinner("Generating predicted plant condition image...")
|
||||
|
||||
st.header("Plant Info")
|
||||
plant_input_mode = st.radio("How would you like to provide plant info?", ("Type name", "Upload image"))
|
||||
plant_type = None
|
||||
uploaded_image = None
|
||||
|
||||
if plant_input_mode == "Type name":
|
||||
plant_type = st.selectbox("Select Plant Type", ["Basil", "Tomato", "Lettuce", "Rosemary", "Other"])
|
||||
plant_type = st.selectbox("Select Plant Type", ["Basil", "Tomato", "Lettuce"])
|
||||
plant_age = st.number_input("Enter Plant Age (in days)", min_value=1, max_value=365, value=25)
|
||||
|
||||
elif plant_input_mode == "Upload image":
|
||||
plant_type = st.selectbox("Select Plant Type", ["Basil", "Tomato", "Lettuce", "Rosemary", "Other"])
|
||||
plant_type = st.selectbox("Select Plant Type", ["Basil", "Tomato", "Lettuce"])
|
||||
plant_age = st.number_input("Enter Plant Age (in days)", min_value=1, max_value=365, value=30)
|
||||
image_file = st.file_uploader("Upload an image of your plant", type=["jpg", "jpeg", "png"])
|
||||
if image_file:
|
||||
uploaded_image = Image.open(image_file)
|
||||
|
@ -98,7 +98,7 @@ with col2:
|
|||
additional_info = st.text_area("Feel free to include any additional detail")
|
||||
|
||||
# Prediction + Description + Image Generation
|
||||
if st.button("Predict Growth Milestone and Generate Description & Image"):
|
||||
if st.button("Start Prediction"):
|
||||
if plant_type and plant_type.strip() != "":
|
||||
if plant_input_mode == "Upload image" and uploaded_image is None:
|
||||
st.warning("Please upload a plant image to proceed.")
|
||||
|
@ -110,40 +110,14 @@ if st.button("Predict Growth Milestone and Generate Description & Image"):
|
|||
)
|
||||
st.subheader(f"📝 Predicted Plant Condition in {days} Days:")
|
||||
st.write(description)
|
||||
|
||||
# Use uploaded image if available, else placeholder or skip image generation
|
||||
if plant_input_mode == "Upload image" and uploaded_image:
|
||||
manipulated_img = generate_condition_image(description, uploaded_image)
|
||||
st.image(manipulated_img, caption="Predicted Plant Condition Image")
|
||||
else:
|
||||
st.info("Image prediction requires uploading a plant image.")
|
||||
|
||||
manipulated_img = generate_image(plant_type, description, plant_age)
|
||||
st.image(manipulated_img, caption="Predicted Plant Condition Image")
|
||||
|
||||
else:
|
||||
st.warning("Please select or enter a plant type.")
|
||||
|
||||
@st.cache_resource
|
||||
def load_sd():
|
||||
model_id = "stabilityai/stable-diffusion-2-1"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_attention_slicing()
|
||||
pipe.to("cpu")
|
||||
return pipe
|
||||
|
||||
pipe = load_sd()
|
||||
|
||||
st.write(description)
|
||||
|
||||
with st.spinner("Generating plant image..."):
|
||||
results = pipe(
|
||||
description,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=3.5,
|
||||
height=512,
|
||||
width=512
|
||||
)
|
||||
image = results.images[0]
|
||||
st.image(image, caption="Predicted Plant Condition", use_column_width=True)
|
||||
|
||||
st.markdown("---")
|
||||
st.caption("Made with ❤️ by Sandwich Craftz.")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue