nafisneehal committed
Commit 95ce3bb · verified · 1 Parent(s): ece978e

Update app.py

Files changed (1): app.py +130 -57
app.py CHANGED
@@ -1,7 +1,9 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList
+import spaces
 import os
+import json
 
 # File to store model links
 MODEL_FILE = "model_links.txt"
@@ -11,8 +13,8 @@ def load_model_links():
     # if not os.path.exists(MODEL_FILE):
     #     # Create default file with some example models
     #     with open(MODEL_FILE, "w") as f:
-    #         f.write("facebook/opt-125m\n")
-    #         f.write("facebook/opt-350m\n")
+    #         f.write("meta-llama/Llama-2-7b-chat-hf\n")
+    #         f.write("tiiuae/falcon-7b-instruct\n")
 
     with open(MODEL_FILE, "r") as f:
         return [line.strip() for line in f.readlines() if line.strip()]
@@ -22,6 +24,7 @@ class ModelManager:
         self.current_model = None
         self.current_tokenizer = None
         self.current_model_name = None
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
     def load_model(self, model_name):
         """Load model and free previous model's memory"""
@@ -30,71 +33,142 @@ class ModelManager:
             del self.current_tokenizer
             torch.cuda.empty_cache()
 
-        self.current_tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.current_model = AutoModelForCausalLM.from_pretrained(model_name)
-        self.current_model_name = model_name
-        return f"Loaded model: {model_name}"
-
-    def generate_response(self, system_message, user_message):
-        """Generate response from the model"""
-        if self.current_model is None:
-            return "Please select and load a model first."
-
-        # Combine system and user messages
-        prompt = f"{system_message}\n\nUser: {user_message}\n\nAssistant:"
-
-        # Generate response
-        inputs = self.current_tokenizer(prompt, return_tensors="pt", padding=True)
-        outputs = self.current_model.generate(
-            inputs.input_ids,
-            max_length=200,
-            num_return_sequences=1,
-            temperature=0.7,
-            pad_token_id=self.current_tokenizer.eos_token_id
-        )
-
-        response = self.current_tokenizer.decode(outputs[0], skip_special_tokens=True)
-        # Extract only the assistant's response
-        response = response.split("Assistant:")[-1].strip()
-        return response
+        try:
+            self.current_tokenizer = AutoTokenizer.from_pretrained(model_name)
+            self.current_model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                load_in_4bit=True,
+                device_map="auto"
+            )
+            self.current_model_name = model_name
+            return f"Successfully loaded model: {model_name}"
+        except Exception as e:
+            return f"Error loading model: {str(e)}"
 
 # Initialize model manager
 model_manager = ModelManager()
 
-# Create Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("# Chat Interface with Model Selection")
+# Default system message for JSON output
+default_system_message = """You are a helpful AI assistant. You must ALWAYS return your response in valid JSON format.
+Each response should be formatted as follows:
+
+{
+    "response": {
+        "main_answer": "Your primary response here",
+        "additional_details": "Any additional information or context",
+        "confidence": 0.0 to 1.0,
+        "tags": ["relevant", "tags", "here"]
+    },
+    "metadata": {
+        "response_type": "type of response",
+        "source": "basis of response if applicable"
+    }
+}
+
+Ensure EVERY response strictly follows this JSON structure."""
+
+@spaces.GPU
+def generate_response(model_name, system_instruction, user_input):
+    """Generate response with GPU support and JSON formatting"""
+    if model_manager.current_model_name != model_name:
+        return json.dumps({"error": "Please load the model first using the 'Load Selected Model' button."}, indent=2)
+
+    if model_manager.current_model is None:
+        return json.dumps({"error": "No model loaded. Please load a model first."}, indent=2)
+
+    # Prepare the prompt with explicit JSON formatting
+    prompt = f"""### Instruction:
+{system_instruction}
+Remember to ALWAYS format your response as valid JSON.
+
+### Input:
+{user_input}
+
+### Response:
+{{"""  # Note the opening curly brace to hint JSON response
+
+    inputs = model_manager.current_tokenizer([prompt], return_tensors="pt").to(model_manager.device)
+
+    # Generation configuration optimized for JSON output
+    meta_config = {
+        "do_sample": False,
+        "temperature": 0.0,
+        "max_new_tokens": 512,
+        "repetition_penalty": 1.1,
+        "use_cache": True,
+        "pad_token_id": model_manager.current_tokenizer.eos_token_id,
+        "eos_token_id": model_manager.current_tokenizer.eos_token_id
+    }
+    generation_config = GenerationConfig(**meta_config)
+
+    # Generate response
+    try:
+        with torch.no_grad():
+            outputs = model_manager.current_model.generate(
+                **inputs,
+                generation_config=generation_config
+            )
+        decoded_output = model_manager.current_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+        assistant_response = decoded_output.split("### Response:")[-1].strip()
+
+        # Clean up and validate JSON
+        try:
+            # Find the last complete JSON object
+            last_brace = assistant_response.rindex('}')
+            assistant_response = assistant_response[:last_brace + 1]
+
+            # Parse and re-format JSON
+            json_response = json.loads(assistant_response)
+            return json.dumps(json_response, indent=2)
+        except (json.JSONDecodeError, ValueError):
+            return json.dumps({
+                "error": "Failed to generate valid JSON",
+                "raw_response": assistant_response
+            }, indent=2)
+
+    except Exception as e:
+        return json.dumps({
+            "error": f"Error generating response: {str(e)}",
+            "details": "An unexpected error occurred during generation"
+        }, indent=2)
+
+# Gradio interface setup
+with gr.Blocks() as demo:
+    gr.Markdown("# Chat Interface with Model Selection (JSON Output)")
 
     with gr.Row():
-        with gr.Column(scale=1):
-            # Input components
+        # Left column for inputs
+        with gr.Column():
             model_dropdown = gr.Dropdown(
                 choices=load_model_links(),
                 label="Select Model",
                 info="Choose a model from the list"
             )
             load_button = gr.Button("Load Selected Model")
-            system_msg = gr.Textbox(
-                label="System Message",
-                placeholder="Enter system message here...",
+            model_status = gr.Textbox(label="Model Status")
+
+            system_instruction = gr.Textbox(
+                value=default_system_message,
+                placeholder="Enter system instruction here...",
+                label="System Instruction",
                 lines=3
             )
-            user_msg = gr.Textbox(
-                label="User Message",
-                placeholder="Enter your message here...",
+            user_input = gr.Textbox(
+                placeholder="Type your message here...",
+                label="Your Message",
                 lines=3
            )
-            submit_button = gr.Button("Generate Response")
-
-        with gr.Column(scale=1):
-            # Output components
-            model_status = gr.Textbox(label="Model Status")
-            chat_output = gr.Textbox(
-                label="Assistant Response",
-                lines=10,
-                interactive=False
+            submit_btn = gr.Button("Submit")
+
+        # Right column for bot response
+        with gr.Column():
+            response_display = gr.Textbox(
+                label="Bot Response (JSON)",
+                interactive=False,
+                placeholder="Response will appear here in JSON format.",
+                lines=10
             )
-
+
     # Event handlers
     load_button.click(
        fn=model_manager.load_model,
@@ -102,12 +176,11 @@ with gr.Blocks() as demo:
         outputs=[model_status]
     )
 
-    submit_button.click(
-        fn=model_manager.generate_response,
-        inputs=[system_msg, user_msg],
-        outputs=[chat_output]
+    submit_btn.click(
+        fn=generate_response,
+        inputs=[model_dropdown, system_instruction, user_input],
+        outputs=[response_display]
     )
 
 # Launch the app
-if __name__ == "__main__":
-    demo.launch()
+demo.launch()
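
A note on the new loading path: load_in_4bit=True is passed directly to from_pretrained, which needs bitsandbytes and accelerate at runtime, and on recent transformers releases that flag is expected to arrive through a quantization config instead. A minimal sketch of the equivalent call under those assumptions (the model id is just an illustrative pick from model_links.txt, nothing this commit pins):

# Sketch only: the same 4-bit load expressed through BitsAndBytesConfig.
# Assumes a recent transformers with bitsandbytes and accelerate installed;
# the model id is a hypothetical entry from model_links.txt.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b-instruct",
    quantization_config=bnb_config,
    device_map="auto"
)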
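
The @spaces.GPU decorator on generate_response comes from the spaces package and is what requests a GPU per call on ZeroGPU hardware; outside that environment it acts as a pass-through. The decorator also accepts a duration argument when a call needs a longer allocation window; a sketch, with the value chosen for illustration rather than taken from this commit:

# Sketch only: a longer ZeroGPU allocation per call. The 120-second duration
# is an assumed value; the commit itself uses the bare @spaces.GPU decorator.
import spaces

@spaces.GPU(duration=120)
def generate_response(model_name, system_instruction, user_input):
    ...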
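
The JSON clean-up inside generate_response keeps everything up to the last closing brace before parsing, which is what rescues outputs where the model keeps writing after the object is complete. A tiny self-contained illustration with a made-up decoded tail:

import json

# Hypothetical text left after splitting the decoded output on "### Response:".
raw = '{"response": {"main_answer": "Hi"}, "metadata": {}} ### Input: stray text'
trimmed = raw[:raw.rindex('}') + 1]  # drop anything after the last '}'
print(json.dumps(json.loads(trimmed), indent=2))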