elapt1c committed (verified)
Commit c35f66d · Parent(s): 5bc59eb

Update app.py

Files changed (1)
  1. app.py  +24 -49
app.py CHANGED
@@ -7,7 +7,7 @@ import torch.nn as nn
 
 # ----- Model Definition -----
 class CustomDialoGPT(nn.Module):
-    def __init__(self, vocab_size, n_embd=768, n_head=8, n_layer=8): # <---- FORCE n_embd, n_head, n_layer to match DialoGPT-medium
+    def __init__(self, vocab_size, n_embd=768, n_head=8, n_layer=8): # <---- FORCE n_embd, n_head, n_layer to match your model
         super().__init__()
 
         config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium",
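Note (not part of the commit): this hunk cuts off inside the AutoConfig.from_pretrained call, so the remaining keyword arguments are not visible here. Below is a minimal sketch of how a wrapper like this is typically assembled; the config overrides and the from_config call are assumptions for illustration, not the file's actual code. The 768/8/8 defaults are smaller than DialoGPT-medium's own dimensions, which is presumably why the comment now says "match your model" instead of "match DialoGPT-medium".

# Minimal sketch, assuming a GPT-2 style wrapper (names below are illustrative).
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM

class CustomDialoGPTSketch(nn.Module):
    def __init__(self, vocab_size, n_embd=768, n_head=8, n_layer=8):
        super().__init__()
        # Override the base config so the architecture matches the checkpoint's dimensions.
        config = AutoConfig.from_pretrained(
            "microsoft/DialoGPT-medium",
            vocab_size=vocab_size,
            n_embd=n_embd,
            n_head=n_head,
            n_layer=n_layer,
        )
        # Randomly initialized model of that shape; real weights come from the .pth later.
        self.transformer = AutoModelForCausalLM.from_config(config)

    def forward(self, input_ids, attention_mask=None):
        return self.transformer(input_ids, attention_mask=attention_mask)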
@@ -39,14 +39,13 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-vocab_size = len(tokenizer)
+vocab_size = len(tokenizer) # <---- Define vocab_size AFTER loading tokenizer
 
 # Initialize model with fixed parameters to match checkpoint
-n_embd=768 # <---- FORCE n_embd to 768
-n_head=8 # <---- FORCE n_head to 12
-n_layer=8 # <---- FORCE n_layer to 12
-model = CustomDialoGPT(vocab_size, n_embd, n_head, n_layer)
-
+n_embd=768
+n_head=8
+n_layer=8
+model = CustomDialoGPT(vocab_size, n_embd, n_head, n_layer).to(device).eval()
 
 # Download and load model weights
 try:
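Note (not part of the commit): defining vocab_size after the tokenizer load and chaining .to(device).eval() onto the constructor means the model is built with the right embedding width before any weights are loaded. A hypothetical pre-load check along these lines can turn a silent size mismatch into a clear error; the wte.weight key lookup assumes GPT-2 style parameter names.

# Hypothetical helper (not in app.py): verify embedding shapes before load_state_dict.
def check_embedding_shape(model, state_dict):
    ckpt_key = next((k for k in state_dict if k.endswith("wte.weight")), None)
    params = dict(model.named_parameters())
    model_key = next((k for k in params if k.endswith("wte.weight")), None)
    if ckpt_key is None or model_key is None:
        return  # parameter naming differs from the GPT-2 layout this sketch assumes
    ckpt_shape = tuple(state_dict[ckpt_key].shape)
    model_shape = tuple(params[model_key].shape)
    if ckpt_shape != model_shape:
        raise ValueError(
            f"Embedding mismatch: checkpoint {ckpt_shape} vs model {model_shape}; "
            "re-check vocab_size, n_embd, n_head and n_layer against the training run."
        )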
@@ -54,13 +53,11 @@ try:
     checkpoint = torch.load(pth_filepath, map_location=device)
 
     # Handle different checkpoint saving formats if needed.
-    # If your checkpoint is just the state_dict, load it directly.
     if 'model_state_dict' in checkpoint:
         model.load_state_dict(checkpoint['model_state_dict'])
     elif 'state_dict' in checkpoint:
         model.load_state_dict(checkpoint['state_dict'])
     else:
-        # Assume checkpoint is just the raw state_dict
         model.load_state_dict(checkpoint)
 
     print(f"Successfully loaded model weights from {model_repo}/{model_filename}")
@@ -72,60 +69,38 @@ except Exception as e:
 
 model.to(device)
 model.eval() # Set model to evaluation mode
 
-def chat_with_model(user_input, history=[]):
-    """Chatbot function to interact with the loaded model."""
-    history_transformer_format = history_to_transformer_format(history)
-    input_text = tokenizer.eos_token.join(history_transformer_format + [user_input])
-
-    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
+def chat_with_model(user_input): # Removed history parameter for gr.Text() output
+    """Chatbot function to interact with the loaded model - DYNAMIC RESPONSE."""
+    input_ids = tokenizer.encode(user_input, return_tensors='pt').to(device)
 
     with torch.no_grad():
-        output = model.transformer.generate( # Use model.transformer.generate here
-            inputs=input_ids, # Use inputs instead of input_ids
-            max_length=1000, # Adjust as needed
+        output = model.transformer.generate(
+            inputs=input_ids,
+            max_length=100,
             pad_token_id=tokenizer.eos_token_id,
             temperature=0.7,
-            top_p=0.9
+            top_p=0.9,
+            do_sample=True
         )
 
     response = tokenizer.decode(output[0], skip_special_tokens=True)
 
-    # Extract only the bot's last response, assuming it's after the last user input.
-    # This is a simple heuristic and might need adjustments based on training data format.
-    split_response = response.split(tokenizer.eos_token)
-    bot_response = split_response[-1].strip()
-
-    # Explicitly format history as list of tuples:
-    history.append((user_input, bot_response))
-
-    # Reformat history for Gradio Chatbot - Ensure tuples within a list
-    chatbot_history = []
-    for turn in history:
-        chatbot_history.append(turn) # Each turn is already a tuple (user_msg, bot_msg)
-
-    return bot_response, chatbot_history # Return chatbot_history for Gradio
 
-def history_to_transformer_format(history):
-    """Convert gradio history to a list of strings for transformer input."""
-    history_formatted = []
-    for user_msg, bot_msg in history:
-        history_formatted.append(user_msg)
-        history_formatted.append(bot_msg)
-    return history_formatted
+    bot_response = response # No need to split for gr.Text()
+
+    print("--- chat_with_model Output ---") # Debugging print
+    print("user_input:", user_input) # Debugging print
+    print("bot_response:", bot_response) # Debugging print
+    print("--- End of chat_with_model Output ---") # Debugging print
+
+    return bot_response # Just return bot_response for gr.Text()
 
 
 iface = gr.Interface( # Changed from gr.ChatInterface to gr.Interface
     fn=chat_with_model,
     inputs=gr.Textbox(placeholder="Type your message here..."), # Explicitly define inputs as gr.Textbox
-    outputs=gr.Chatbot(), # Explicitly define outputs as gr.Chatbot
-    title="ElapticAI-1a Chatbot",
-    description="Simple chatbot interface for ElapticAI-1a model. Talk to the model and see its responses!",
-    examples=[ # Corrected examples format
-        ["Hello", "Hi there!"], # Example 1: [user_input, bot_response]
-        ["How are you?", "I am doing well, thank you."], # Example 2
-        ["Tell me a joke", "Why don't scientists trust atoms? Because they make up everything! 😄"] # Example 3
-    ]
+    outputs=gr.Text(), # Changed outputs to gr.Text()
+    title="ElapticAI-1a Chatbot - TESTING MODEL RESPONSE", # Updated title
+    description="Simple chatbot interface for ElapticAI-1a model - TESTING MODEL RESPONSE" # Updated description
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
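Note (not part of the commit): chat_with_model is now single-turn and returns the full decoded sequence, prompt included, since the EOS-splitting and history handling were removed; the Interface output is a plain gr.Text() to match. A small smoke test like the following, run after app.py has defined tokenizer, model and chat_with_model, is one way to confirm the model actually responds before launching the UI; it is a sketch, not part of the file.

# Hypothetical smoke test (run in the same process as app.py, before iface.launch()).
def smoke_test():
    for prompt in ["Hello", "Tell me a joke"]:
        reply = chat_with_model(prompt)
        assert isinstance(reply, str) and reply, "model returned no text"
        print(f"{prompt!r} -> {reply!r}")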
 