yolac commited on
Commit
1b76d00
·
verified ·
1 Parent(s): 5d0d983

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -66
app.py CHANGED
@@ -1,67 +1,57 @@
1
- import os
2
- import requests
3
- from datasets import load_dataset
4
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
  import torch
6
-
7
- # Step 1: Set up environment and check paths
8
- MODEL_PATH = "yolac/BacterialMorphologyClassification"
9
- DATASET_PATH = "yolac/BacterialMorphologyClassification"
10
-
11
- # Step 2: Load the model
12
- try:
13
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH, num_labels=3)
14
- print("Model loaded successfully.")
15
- except Exception as e:
16
- print(f"Failed to load the model. Error: {e}")
17
-
18
- # Step 3: Load the tokenizer (if needed)
19
- try:
20
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
21
- print("Tokenizer loaded successfully.")
22
- except Exception as e:
23
- print(f"Failed to load the tokenizer. Error: {e}")
24
-
25
- # Step 4: Load the dataset
26
- try:
27
- dataset = load_dataset(DATASET_PATH, split="train")
28
- print("Dataset loaded successfully.")
29
- except Exception as e:
30
- print(f"Failed to load the dataset. Error: {e}")
31
-
32
- # Step 5: Preprocess and prepare data for model input (example code)
33
- def preprocess_data(example):
34
- # Add any necessary preprocessing steps here, e.g., tokenization
35
- return tokenizer(example['text'], padding="max_length", truncation=True)
36
-
37
- # Apply preprocessing
38
- dataset = dataset.map(preprocess_data, batched=True)
39
-
40
- # Step 6: Set up training arguments (use the `Trainer` class if needed)
41
- from transformers import Trainer, TrainingArguments
42
-
43
- training_args = TrainingArguments(
44
- output_dir="./results",
45
- evaluation_strategy="epoch",
46
- learning_rate=2e-5,
47
- per_device_train_batch_size=8,
48
- per_device_eval_batch_size=8,
49
- num_train_epochs=3,
50
- weight_decay=0.01,
51
- logging_dir="./logs",
52
- )
53
-
54
- # Initialize the Trainer
55
- trainer = Trainer(
56
- model=model,
57
- args=training_args,
58
- train_dataset=dataset,
59
- tokenizer=tokenizer,
60
- )
61
-
62
- # Step 7: Train the model
63
- try:
64
- trainer.train()
65
- print("Training completed successfully.")
66
- except Exception as e:
67
- print(f"Failed to train the model. Error: {e}")
 
 
 
 
 
1
  import torch
2
+ from fastapi import FastAPI, UploadFile, File
3
+ from pydantic import BaseModel
4
+ from io import BytesIO
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
8
+
9
+ # Initialize the FastAPI app
10
+ app = FastAPI()
11
+
12
+ # Load the model and feature extractor
13
+ model = AutoModelForImageClassification.from_pretrained("yolac/BacterialMorphologyClassification")
14
+ feature_extractor = AutoFeatureExtractor.from_pretrained("yolac/BacterialMorphologyClassification")
15
+
16
+ # Define the image preprocessing transform
17
+ preprocess = transforms.Compose([
18
+ transforms.Resize((224, 224)), # Adjust size as needed for your model
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
21
+ ])
22
+
23
+ # Helper function to preprocess the image
24
+ def preprocess_image(image: Image.Image):
25
+ image = image.convert("RGB") # Ensure the image is in RGB format
26
+ image = preprocess(image)
27
+ image = image.unsqueeze(0) # Add a batch dimension
28
+ return image
29
+
30
+ # Define the prediction endpoint
31
+ @app.post("/predict/")
32
+ async def predict(file: UploadFile = File(...)):
33
+ try:
34
+ # Read the uploaded image
35
+ image_data = await file.read()
36
+ image = Image.open(BytesIO(image_data))
37
+
38
+ # Preprocess the image
39
+ image_tensor = preprocess_image(image)
40
+
41
+ # Perform inference
42
+ model.eval()
43
+ with torch.no_grad():
44
+ outputs = model(image_tensor)
45
+ logits = outputs.logits # Get the model's raw output
46
+ predicted_class_idx = torch.argmax(logits, dim=1).item()
47
+
48
+ # Map the predicted class index to the class labels
49
+ class_labels = ["Cocci", "Bacilli", "Spirilla"] # Replace with your actual class labels
50
+ predicted_class = class_labels[predicted_class_idx]
51
+
52
+ return {"predicted_class": predicted_class}
53
+
54
+ except Exception as e:
55
+ return {"error": str(e)}
56
+
57
+ # Run the app using the command: uvicorn app:app --reload