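# Spaces: Sleeping
# NOTE(review): the line above appears to be hosting-page status text captured
# together with the file; it is not part of the program.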
import html
import json
import os
import pickle
import zipfile
from io import BytesIO

import numpy as np
import streamlit as st
from PIL import Image
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
# Directories
# Uploaded images, the pickled model and the label map all live in UPLOAD_DIR.
UPLOAD_DIR = "uploaded_images"
MODEL_PATH = os.path.join(UPLOAD_DIR, "model.pkl")
LABELS_PATH = os.path.join(UPLOAD_DIR, "labels.json")
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Initialize model and labels
# Reuse the state persisted by a previous run when it exists; otherwise start
# from an unfitted classifier and an empty {image path: label} map.
if os.path.exists(MODEL_PATH):
    # SECURITY: unpickling executes arbitrary code embedded in the file.
    # MODEL_PATH sits in the same directory that receives user uploads, so
    # this is only safe while nothing but this app can write "model.pkl".
    with open(MODEL_PATH, "rb") as f:
        model = pickle.load(f)
else:
    model = RandomForestClassifier()

if os.path.exists(LABELS_PATH):
    with open(LABELS_PATH, "r") as f:
        labels = json.load(f)
else:
    labels = {}
# Helper function to save uploaded images
def save_uploaded_image(uploaded_file, upload_dir=None):
    """Write an uploaded file to disk and return the path it was saved under.

    Args:
        uploaded_file: Upload object exposing ``name`` and ``getbuffer()``.
        upload_dir: Destination directory. Defaults to the module-level
            ``UPLOAD_DIR`` when omitted.

    Returns:
        The path of the written file inside the destination directory.

    Raises:
        ValueError: If the upload's name has no usable file-name component.
    """
    target_dir = UPLOAD_DIR if upload_dir is None else upload_dir
    # The name is chosen by the client. Keep only its final component (after
    # either kind of slash) so a name like "../../x" cannot leave target_dir.
    safe_name = os.path.basename(str(uploaded_file.name).replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {uploaded_file.name!r}")
    file_path = os.path.join(target_dir, safe_name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path
# Feature extraction function
def extract_features(file_path):
    """Return a flat pixel-feature vector for the image at ``file_path``.

    The image is converted to RGB and resized to 64x64 before flattening, so
    every successfully read image yields a vector of the same length
    (64 * 64 * 3).

    Args:
        file_path: Path of the image file to read.

    Returns:
        A 1-D numpy array of pixel values, or ``None`` when the file could not
        be opened or processed (the error is printed, not raised).
    """
    try:
        with Image.open(file_path) as img:
            # Normalise the mode first: without it an RGBA, greyscale or
            # palette image flattens to a different length than an RGB one,
            # which breaks stacking the vectors for fit() and predict().
            rgb = img.convert("RGB")
            return np.array(rgb.resize((64, 64))).flatten()
    except Exception as e:
        # Deliberately best-effort: callers treat None as "skip this image".
        print(f"Error processing image {file_path}: {e}")
        return None
# Train the model
def train_model(training_data):
    """Fit the module-level model on labeled images and persist the result.

    Features are extracted from every path in ``training_data``; paths whose
    image cannot be read are skipped. When at least one image is usable the
    model is fit, pickled to ``MODEL_PATH``, and the module-level ``labels``
    map is written to ``LABELS_PATH``.

    Args:
        training_data: Mapping of image file path to class label.

    Returns:
        ``True`` if the model was fit and saved, ``False`` if no usable image
        was found (nothing is fit or written in that case).
    """
    # No ``global`` statement needed: ``model`` and ``labels`` are only read
    # and mutated in place here, never rebound.
    features, targets = [], []
    for file_path, label in training_data.items():
        feature = extract_features(file_path)
        if feature is not None:
            features.append(feature)
            targets.append(label)

    if not features:
        print("No valid features found for training.")
        return False

    model.fit(np.array(features), np.array(targets))
    with open(MODEL_PATH, "wb") as f:
        pickle.dump(model, f)
    with open(LABELS_PATH, "w") as f:
        json.dump(labels, f)
    return True
# Classify an image
def classify_image(file_path):
    """Predict the class label of the image at ``file_path``.

    Args:
        file_path: Path of the image file to classify.

    Returns:
        The label predicted by the module-level model, or the string
        ``"Invalid Image"`` when no features could be extracted.

    Raises:
        RuntimeError: If the model has not been fitted yet.
    """
    features = extract_features(file_path)
    if features is None:
        return "Invalid Image"
    # A fitted scikit-learn classifier exposes ``classes_``. Checking for it
    # gives a clear message instead of the library's not-fitted error.
    if not hasattr(model, "classes_"):
        raise RuntimeError(
            "The model has not been trained yet; add training data in the "
            "'Train Model' tab first."
        )
    # predict() expects a 2-D array: one row per sample.
    return model.predict(features.reshape(1, -1))[0]
# Create a sample dataset for download
def create_sample_dataset(sample_dir="sample_dataset", images_per_class=3):
    """Write placeholder images to disk and return them as a ZIP archive.

    Creates ``Human`` and ``Alien`` sub-directories under ``sample_dir``,
    fills each with solid-colour 64x64 JPEG placeholders, then zips every file
    found under ``sample_dir`` with paths relative to it.

    Args:
        sample_dir: Directory the dataset is written to (created if missing).
        images_per_class: Number of placeholder images written per class.

    Returns:
        The bytes of the ZIP archive.
    """
    # Add placeholder images (replace with real images in a practical application)
    class_colors = {"Human": (255, 0, 0), "Alien": (0, 255, 0)}
    for class_name, color in class_colors.items():
        class_dir = os.path.join(sample_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)
        for i in range(1, images_per_class + 1):
            image_path = os.path.join(class_dir, f"{class_name.lower()}_{i}.jpg")
            Image.new("RGB", (64, 64), color=color).save(image_path)

    # Create a ZIP file for download
    zip_buffer = BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as zip_file:
        for folder_name, subfolders, filenames in os.walk(sample_dir):
            # Sorting in place steers os.walk's traversal, so the archive's
            # entry order does not depend on the file system's listing order.
            subfolders.sort()
            for filename in sorted(filenames):
                file_path = os.path.join(folder_name, filename)
                arcname = os.path.relpath(file_path, sample_dir)
                zip_file.write(file_path, arcname)
    return zip_buffer.getvalue()
# Streamlit app
_IMAGE_TYPES = ["jpg", "jpeg", "png"]


def _render_identify_tab():
    """Render the tab that saves and classifies freshly uploaded images."""
    st.header("Identify Image")
    uploaded_files = st.file_uploader(
        "Upload Images to Identify", type=_IMAGE_TYPES, accept_multiple_files=True
    )
    if not uploaded_files:
        return

    results = {}
    for uploaded_file in uploaded_files:
        try:
            st.image(
                uploaded_file,
                caption=f"Uploaded Image: {uploaded_file.name}",
                use_container_width=True,
            )
            file_path = save_uploaded_image(uploaded_file)
            results[file_path] = classify_image(file_path)
        except Exception as e:
            st.error(f"Error processing file {uploaded_file.name}: {e}")

    st.subheader("Classification Results")
    for file_path, label in results.items():
        # This markup is rendered as raw HTML, and the file name comes from
        # the client, so both values are escaped before interpolation.
        image_name = html.escape(os.path.basename(file_path))
        shown_label = html.escape(str(label))
        st.markdown(
            f"<p style='font-size:20px; color:blue; background-color:lightyellow; padding:10px; border-radius:5px;'>Image: {image_name} - Classified as: {shown_label}</p>",
            unsafe_allow_html=True,
        )


def _train_on_all_labels(training_data, success_message):
    """Merge ``training_data`` into ``labels`` and refit on every label.

    The fit runs over the whole accumulated label map, not just the new
    batch, so earlier examples are not dropped by a refit on later ones.
    """
    labels.update(training_data)
    # Only an explicit False means "nothing usable"; a None result is treated
    # as success so this also works with a train_model that returns nothing.
    if train_model(labels) is False:
        st.warning("No valid images were found, so the model was not updated.")
    else:
        st.success(success_message)


def _load_dataset_paths(dataset_file):
    """Parse an uploaded JSON dataset into an {image path: label} mapping.

    Returns ``None`` (after showing an error) when the file is not valid JSON
    or is not an object. Paths that do not exist on disk are ignored.
    """
    try:
        dataset = json.load(dataset_file)
    except ValueError as e:  # also covers JSONDecodeError / UnicodeDecodeError
        st.error(f"Could not read the dataset file: {e}")
        return None
    if not isinstance(dataset, dict):
        st.error("The dataset must be a JSON object mapping each label to a list of image paths.")
        return None

    # NOTE(review): the paths are taken verbatim from the uploaded file and
    # refer to files on the machine running the app; confirm that is intended.
    training_data = {}
    for label, image_paths in dataset.items():
        if isinstance(image_paths, str):
            image_paths = [image_paths]  # accept a single path as a 1-item list
        elif not isinstance(image_paths, list):
            continue
        for image_path in image_paths:
            if isinstance(image_path, str) and os.path.exists(image_path):
                training_data[image_path] = label
    return training_data


def _render_labeled_images(label):
    """Show every image recorded in ``labels`` under the given class label."""
    st.subheader(f"{label} Images")
    image_paths = [img for img, lbl in labels.items() if lbl == label]
    if not image_paths:
        st.warning(f"No {label.lower()} images found for training.")
        return
    st.info("These images have already been trained.")
    for image_path in image_paths:
        try:
            st.image(
                image_path,
                caption=f"Image: {os.path.basename(image_path)}",
                use_container_width=True,
            )
        except Exception as e:
            st.error(f"Error displaying image {os.path.basename(image_path)}: {e}")


def _render_train_tab():
    """Render the tab for labeling uploads, bulk training and review."""
    st.header("Train Model")
    st.subheader("Upload Images for Training")
    training_files = st.file_uploader(
        "Upload Training Images", type=_IMAGE_TYPES, accept_multiple_files=True
    )
    if training_files:
        training_data = {}
        for index, training_file in enumerate(training_files):
            try:
                st.image(
                    training_file,
                    caption=f"Uploaded Training Image: {training_file.name}",
                    use_container_width=True,
                )
                file_path = save_uploaded_image(training_file)
                label = st.radio(
                    f"Classify {training_file.name}",
                    ["Human", "Alien"],
                    index=0,
                    # The position keeps keys unique when two uploads share a name.
                    key=f"train_label_{index}_{training_file.name}",
                )
                training_data[file_path] = label
            except Exception as e:
                st.error(f"Error processing file {training_file.name}: {e}")
        if st.button("Save Training Data"):
            _train_on_all_labels(
                training_data,
                "Training data has been saved and the model updated.",
            )

    st.subheader("Bulk Train with Existing Dataset")
    dataset_file = st.file_uploader("Upload a JSON file containing labeled data", type=["json"])
    if dataset_file:
        training_data = _load_dataset_paths(dataset_file)
        if training_data is not None and st.button("Train Model with Dataset"):
            _train_on_all_labels(
                training_data,
                "Model has been trained with the uploaded dataset.",
            )

    subtab1, subtab2 = st.tabs(["Humans", "Aliens"])
    with subtab1:
        _render_labeled_images("Human")
    with subtab2:
        _render_labeled_images("Alien")


def _render_download_tab():
    """Render the tab offering the generated sample dataset as a ZIP."""
    st.header("Download Sample Dataset")
    if st.button("Download Sample Dataset"):
        sample_dataset = create_sample_dataset()
        st.download_button(
            label="Click to Download",
            data=sample_dataset,
            file_name="sample_dataset.zip",
            mime="application/zip",
        )


def main():
    """Build the Streamlit page: title, introduction and the three tabs."""
    st.title("Human or Alien Identification")
    st.markdown(
        """
Welcome to the **Human or Alien Identification App**! Here's what you can do:
- **Identify Image:** Upload an image and classify it as "Human" or "Alien." The classifications you save will be added to the training data.
- **Train Model:** Review and manage the images already classified as "Human" or "Alien." Upload additional images to improve the training dataset.
- **Download Sample Dataset:** Download a pre-structured dataset to use for training and classification.
"""
    )
    tab1, tab2, tab3 = st.tabs(["Identify Image", "Train Model", "Download Sample Dataset"])
    with tab1:
        _render_identify_tab()
    with tab2:
        _render_train_tab()
    with tab3:
        _render_download_tab()
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()