Update app.py
app.py
CHANGED
@@ -7,7 +7,7 @@ import torch.nn as nn
 
 # ----- Model Definition -----
 class CustomDialoGPT(nn.Module):
-    def __init__(self, vocab_size, n_embd=768, n_head=8, n_layer=8): # <---- FORCE n_embd, n_head, n_layer to match
+    def __init__(self, vocab_size, n_embd=768, n_head=8, n_layer=8): # <---- FORCE n_embd, n_head, n_layer to match your model
         super().__init__()
 
         config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium",
@@ -39,14 +39,13 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-vocab_size = len(tokenizer)
+vocab_size = len(tokenizer) # <---- Define vocab_size AFTER loading tokenizer
 
 # Initialize model with fixed parameters to match checkpoint
-n_embd=768
-n_head=8
-n_layer=8
-model = CustomDialoGPT(vocab_size, n_embd, n_head, n_layer)
-
+n_embd=768
+n_head=8
+n_layer=8
+model = CustomDialoGPT(vocab_size, n_embd, n_head, n_layer).to(device).eval()
 
 # Download and load model weights
 try:
@@ -54,13 +53,11 @@ try:
     checkpoint = torch.load(pth_filepath, map_location=device)
 
     # Handle different checkpoint saving formats if needed.
-    # If your checkpoint is just the state_dict, load it directly.
     if 'model_state_dict' in checkpoint:
         model.load_state_dict(checkpoint['model_state_dict'])
     elif 'state_dict' in checkpoint:
         model.load_state_dict(checkpoint['state_dict'])
     else:
-        # Assume checkpoint is just the raw state_dict
         model.load_state_dict(checkpoint)
 
     print(f"Successfully loaded model weights from {model_repo}/{model_filename}")
@@ -72,60 +69,38 @@ except Exception as e:
 model.to(device)
 model.eval() # Set model to evaluation mode
 
-def chat_with_model(user_input, history):
-    """Chatbot function to interact with the loaded model."""
-
-    input_text = tokenizer.eos_token.join(history_transformer_format + [user_input])
-
-    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
+def chat_with_model(user_input): # Removed history parameter for gr.Text() output
+    """Chatbot function to interact with the loaded model - DYNAMIC RESPONSE."""
+    input_ids = tokenizer.encode(user_input, return_tensors='pt').to(device)
 
     with torch.no_grad():
-        output = model.transformer.generate(
-            inputs=input_ids,
-            max_length=
+        output = model.transformer.generate(
+            inputs=input_ids,
+            max_length=100,
             pad_token_id=tokenizer.eos_token_id,
             temperature=0.7,
-            top_p=0.9
+            top_p=0.9,
+            do_sample=True
         )
 
     response = tokenizer.decode(output[0], skip_special_tokens=True)
+    bot_response = response # No need to split for gr.Text()
 
-
-    #
-
-
-
-    # Explicitly format history as list of tuples:
-    history.append((user_input, bot_response))
-
-    # Reformat history for Gradio Chatbot - Ensure tuples within a list
-    chatbot_history = []
-    for turn in history:
-        chatbot_history.append(turn) # Each turn is already a tuple (user_msg, bot_msg)
-
-    return bot_response, chatbot_history # Return chatbot_history for Gradio
+    print("--- chat_with_model Output ---") # Debugging print
+    print("user_input:", user_input) # Debugging print
+    print("bot_response:", bot_response) # Debugging print
+    print("--- End of chat_with_model Output ---") # Debugging print
 
-
-    """Convert gradio history to a list of strings for transformer input."""
-    history_formatted = []
-    for user_msg, bot_msg in history:
-        history_formatted.append(user_msg)
-        history_formatted.append(bot_msg)
-    return history_formatted
+    return bot_response # Just return bot_response for gr.Text()
 
 
 iface = gr.Interface( # Changed from gr.ChatInterface to gr.Interface
     fn=chat_with_model,
     inputs=gr.Textbox(placeholder="Type your message here..."), # Explicitly define inputs as gr.Textbox
-    outputs=gr.
-    title="ElapticAI-1a Chatbot",
-    description="Simple chatbot interface for ElapticAI-1a model",
-    examples=[ # Corrected examples format
-        ["Hello", "Hi there!"], # Example 1: [user_input, bot_response]
-        ["How are you?", "I am doing well, thank you."], # Example 2
-        ["Tell me a joke", "Why don't scientists trust atoms? Because they make up everything! 😄"] # Example 3
-    ]
+    outputs=gr.Text(), # Changed outputs to gr.Text()
+    title="ElapticAI-1a Chatbot - TESTING MODEL RESPONSE", # Updated title
+    description="Simple chatbot interface for ElapticAI-1a model - TESTING MODEL RESPONSE" # Updated description
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
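For context, the fragments above fit together roughly as sketched below. This is a minimal reconstruction under stated assumptions, not the actual file: the body of CustomDialoGPT (assumed to wrap a GPT-2-style LM under .transformer, since the app calls model.transformer.generate), the hf_hub_download call, and the tokenizer_name/model_repo/model_filename values are placeholders inferred from the diff; only the lines shown in the hunks above are known.

# Minimal sketch of how the revised app.py pieces fit together (assumptions marked).
import torch
import torch.nn as nn
import gradio as gr
from huggingface_hub import hf_hub_download          # assumed download mechanism
from transformers import AutoConfig, AutoTokenizer, GPT2LMHeadModel

class CustomDialoGPT(nn.Module):
    def __init__(self, vocab_size, n_embd=768, n_head=8, n_layer=8):
        super().__init__()
        # Assumption: the wrapper holds a GPT-2-style LM under .transformer,
        # matching the model.transformer.generate(...) call in the app.
        config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium",
                                            vocab_size=vocab_size,
                                            n_embd=n_embd,
                                            n_head=n_head,
                                            n_layer=n_layer)
        self.transformer = GPT2LMHeadModel(config)

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer_name = "microsoft/DialoGPT-medium"         # placeholder value
model_repo = "<user>/<model-repo>"                   # placeholder value
model_filename = "model.pth"                         # placeholder value

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
vocab_size = len(tokenizer)                          # define vocab_size after loading tokenizer
model = CustomDialoGPT(vocab_size, n_embd=768, n_head=8, n_layer=8).to(device).eval()

try:
    pth_filepath = hf_hub_download(repo_id=model_repo, filename=model_filename)
    checkpoint = torch.load(pth_filepath, map_location=device)
    # Handle both wrapped and raw state_dict checkpoint formats, as in the diff.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)            # assume a raw state_dict
    print(f"Successfully loaded model weights from {model_repo}/{model_filename}")
except Exception as e:
    print(f"Failed to load model weights: {e}")

def chat_with_model(user_input):
    """Generate a single reply for one user message (no chat history)."""
    input_ids = tokenizer.encode(user_input, return_tensors='pt').to(device)
    with torch.no_grad():
        output = model.transformer.generate(
            inputs=input_ids,
            max_length=100,
            pad_token_id=tokenizer.eos_token_id,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)

iface = gr.Interface(
    fn=chat_with_model,
    inputs=gr.Textbox(placeholder="Type your message here..."),
    outputs=gr.Text(),
    title="ElapticAI-1a Chatbot",
)

if __name__ == "__main__":
    iface.launch()

One detail worth noting about the change itself: temperature and top_p only affect generation when sampling is enabled, so the added do_sample=True line is what actually makes those parameters take effect; without it, generate falls back to greedy decoding. After launching the Space, typing a message in the Textbox should print the "--- chat_with_model Output ---" debug lines in the logs, which is the point of the "TESTING MODEL RESPONSE" revision.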