kendrickfff commited on
Commit
ddb0e33
·
verified ·
1 Parent(s): 91e0dca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -39
app.py CHANGED
@@ -1,19 +1,19 @@
 
1
  import gradio as gr
2
  from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
3
  from PIL import Image
4
  import torch
5
  from torchvision import models, transforms
6
- import json
7
- import requests
8
 
9
- # Initialize the chat model with Hugging Face-specific environment variables
10
- llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
 
11
 
12
- # Load a pre-trained ResNet50 model for image classification
13
  model = models.resnet50(pretrained=True)
14
- model.eval()
15
 
16
- # Transformation pipeline for image preprocessing
17
  transform = transforms.Compose([
18
  transforms.Resize(256),
19
  transforms.CenterCrop(224),
@@ -21,62 +21,68 @@ transform = transforms.Compose([
21
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
22
  ])
23
 
24
- # Load ImageNet labels
25
  LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
26
- labels = json.loads(requests.get(LABELS_URL).text)
 
 
 
 
 
 
27
 
28
- # Global chat history variable
29
- chat_history = []
 
30
 
31
- def chat_with_gemini(message):
32
- global chat_history
33
- # Get a response from the language model
34
  bot_response = llm.predict(message)
35
  chat_history.append((message, bot_response))
36
- return chat_history
 
37
 
38
- def analyze_image(image_path):
39
- global chat_history
40
- # Open, preprocess, and classify the image
41
  image = Image.open(image_path).convert("RGB")
42
  image_tensor = transform(image).unsqueeze(0)
43
-
 
44
  with torch.no_grad():
45
  outputs = model(image_tensor)
46
  _, predicted_idx = outputs.max(1)
47
 
 
48
  label = labels[predicted_idx.item()]
 
 
49
  bot_response = f"The image seems to be: {label}."
50
  chat_history.append(("Uploaded an image for analysis", bot_response))
51
- return chat_history
 
52
 
53
- # Build the Gradio interface
54
- with gr.Blocks() as demo:
55
  gr.Markdown("# Ken Chatbot")
56
  gr.Markdown("Ask me anything or upload an image for analysis!")
57
 
58
- # Chatbot display without "User" or "Bot" labels
59
  chatbot = gr.Chatbot(elem_id="chatbot")
60
-
61
  # User input components
62
- msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...", show_label=False)
63
  send_btn = gr.Button("Send")
64
  img_upload = gr.Image(type="filepath", label="Upload an image for analysis")
65
 
66
- # Define interactions
67
- def handle_text_message(message):
68
- return chat_with_gemini(message)
69
-
70
- def handle_image_upload(image_path):
71
- return analyze_image(image_path)
72
 
73
- # Set up Gradio components with Enter key for sending
74
- msg.submit(handle_text_message, msg, chatbot)
75
- send_btn.click(handle_text_message, msg, chatbot)
76
- send_btn.click(lambda: "", None, msg) # Clear input field
77
- img_upload.change(handle_image_upload, img_upload, chatbot)
78
 
79
- # Custom CSS for styling without usernames
80
  gr.HTML("""
81
  <style>
82
  #chatbot .message-container {
@@ -104,5 +110,5 @@ with gr.Blocks() as demo:
104
  </style>
105
  """)
106
 
107
- # Launch for Hugging Face Spaces
108
- demo.launch()
 
1
+ import os
2
  import gradio as gr
3
  from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
4
  from PIL import Image
5
  import torch
6
  from torchvision import models, transforms
 
 
7
 
8
+ # Set up the environment for Google Generative AI
9
+ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./firm-catalyst-437006-s4-407500537db5.json"
10
+ llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
11
 
12
+ # Load a pre-trained ResNet50 model for image analysis
13
  model = models.resnet50(pretrained=True)
14
+ model.eval() # Set the model to evaluation mode
15
 
16
+ # Define the transformation for the image
17
  transform = transforms.Compose([
18
  transforms.Resize(256),
19
  transforms.CenterCrop(224),
 
21
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
22
  ])
23
 
24
+ # Load the ImageNet labels
25
  LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
26
+ labels = None
27
+
28
+ if not os.path.exists("imagenet_labels.json"):
29
+ import requests
30
+ response = requests.get(LABELS_URL)
31
+ with open("imagenet_labels.json", "wb") as f:
32
+ f.write(response.content)
33
 
34
+ import json
35
+ with open("imagenet_labels.json") as f:
36
+ labels = json.load(f)
37
 
38
+ def chat_with_gemini(message, chat_history):
39
+ # Generate a response from the language model
 
40
  bot_response = llm.predict(message)
41
  chat_history.append((message, bot_response))
42
+
43
+ return chat_history, chat_history
44
 
45
+ def analyze_image(image_path, chat_history):
46
+ # Load and preprocess the image
 
47
  image = Image.open(image_path).convert("RGB")
48
  image_tensor = transform(image).unsqueeze(0)
49
+
50
+ # Predict the image class
51
  with torch.no_grad():
52
  outputs = model(image_tensor)
53
  _, predicted_idx = outputs.max(1)
54
 
55
+ # Retrieve the label
56
  label = labels[predicted_idx.item()]
57
+
58
+ # Respond with the classification result
59
  bot_response = f"The image seems to be: {label}."
60
  chat_history.append(("Uploaded an image for analysis", bot_response))
61
+
62
+ return chat_history, chat_history
63
 
64
+ # Create Gradio interface
65
+ with gr.Blocks() as iface:
66
  gr.Markdown("# Ken Chatbot")
67
  gr.Markdown("Ask me anything or upload an image for analysis!")
68
 
69
+ # Chatbot component without usernames
70
  chatbot = gr.Chatbot(elem_id="chatbot")
71
+
72
  # User input components
73
+ msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...")
74
  send_btn = gr.Button("Send")
75
  img_upload = gr.Image(type="filepath", label="Upload an image for analysis")
76
 
77
+ # State for chat history
78
+ state = gr.State([])
 
 
 
 
79
 
80
+ # Define interactions
81
+ send_btn.click(chat_with_gemini, [msg, state], [chatbot, state]) # Handle text input
82
+ send_btn.click(lambda: "", None, msg) # Clear textbox
83
+ img_upload.change(analyze_image, [img_upload, state], [chatbot, state]) # Handle image uploads
 
84
 
85
+ # Custom CSS for styling chat bubbles without usernames
86
  gr.HTML("""
87
  <style>
88
  #chatbot .message-container {
 
110
  </style>
111
  """)
112
 
113
+ # Launch the Gradio interface
114
+ iface.launch(debug=True)