team-6/app.py
2025-08-02 13:12:37 +02:00

115 lines
No EOL
4.1 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import traceback
from flask import Flask, request, jsonify
from PIL import Image
import numpy as np
import torch
import clip
import tensorflow as tf
# --- Configuration ---
class Config:
"""Groups all required configuration variables in one place."""
PLANT_MODEL_PATH = 'models/BestModel.keras'
PLANT_CLASSES = ['tomato', 'basil', 'mint', 'lettuce', 'rosemary', 'strawberry']
IMG_SIZE = (384, 384)
# Health analysis prompt
HEALTH_PROMPTS = [
"a photo of a healthy {plant} plant with vibrant green leaves",
"a photo of a sick {plant} plant with yellow spots or discoloration",
"a photo of a dehydrated {plant} plant with wilted or drooping leaves",
"a photo of a dead {plant} plant with brown, dry, or crispy leaves"
]
HEALTH_LABELS = ["Healthy", "Diseased", "Dehydrated", "Dead"]
# --- Application Setup ---
app = Flask(__name__)
# --- Model Loading ---
def load_models():
"""Loads and initializes all required machine learning models."""
# Load plant identification model
print("1. Loading plant identification model...")
plant_model = tf.keras.models.load_model(Config.PLANT_MODEL_PATH)
# Load model for health analysis
print("2. Loading CLIP model for health analysis...")
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
print("\nAll models loaded successfully.")
return plant_model, clip_model, clip_preprocess, device
# Load models
plant_model, clip_model, clip_preprocess, device = load_models()
# --- Core ML Functions ---
def identify_plant(image):
"""Plant identification"""
img = image.resize(Config.IMG_SIZE)
img_array = tf.keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)
img_array = tf.keras.applications.mobilenet_v2.preprocess_input(img_array)
preds = plant_model.predict(img_array, verbose=0)
best_idx = np.argmax(preds[0])
plant_name = Config.PLANT_CLASSES[best_idx]
confidence = float(np.max(preds))
return plant_name, confidence
def assess_health(plant_name, image):
"""Plant health"""
prompts = [p.format(plant=plant_name) for p in Config.HEALTH_PROMPTS]
image_input = clip_preprocess(image).unsqueeze(0).to(device)
text_tokens = clip.tokenize(prompts).to(device)
with torch.no_grad():
image_features = clip_model.encode_image(image_input)
text_features = clip_model.encode_text(text_tokens)
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
probs = similarity.cpu().numpy()[0]
status = Config.HEALTH_LABELS[np.argmax(probs)]
confidence = float(np.max(probs))
probabilities = {label: float(p) for label, p in zip(Config.HEALTH_LABELS, probs)}
return status, confidence, probabilities
# --- API Endpoint ---
@app.route('/analyze', methods=['POST'])
def analyze_plant_image():
"""Image analysis."""
if 'image' not in request.files:
return jsonify({'error': 'No image file provided'}), 400
try:
file = request.files['image']
image = Image.open(io.BytesIO(file.read())).convert("RGB")
plant_name, plant_conf = identify_plant(image.copy())
health_status, health_conf, health_probs = assess_health(plant_name, image)
return jsonify({
'plant_species': plant_name,
'identification_confidence': f"{plant_conf:.2%}",
'health_status': health_status,
'health_confidence': f"{health_conf:.2%}",
'health_breakdown': health_probs
})
except Exception as e:
print("An error occurred:", str(e))
traceback.print_exc()
return jsonify({'error': 'An internal server error occurred.'}), 500
# --- Main Execution ---
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)