Theresa Hoesl committed
Commit 9c587b3 · 1 Parent(s): 007fe97

corrected respond function

Files changed (1):
  1. app.py +84 -0
app.py CHANGED
@@ -1,3 +1,4 @@
+'''
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import AutoPeftModelForCausalLM
@@ -80,5 +81,88 @@ demo = gr.ChatInterface(
     ],
 )
 
+if __name__ == "__main__":
+    demo.launch()
+'''
+
+import gradio as gr
+from transformers import AutoTokenizer
+from peft import AutoPeftModelForCausalLM
+import torch
+
+# Load the model and tokenizer once at startup
+def load_model():
+    lora_model_name = "sreyanghosh/lora_model"  # Replace with your LoRA model path
+    model = AutoPeftModelForCausalLM.from_pretrained(
+        lora_model_name,
+        load_in_4bit=False,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(lora_model_name)
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+    model.eval()
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = model.to(device)
+    return tokenizer, model
+
+tokenizer, model = load_model()
+
+# Define the respond function
+def respond(
+    message,
+    history: list[tuple[str, str]],
+    system_message,
+    max_tokens,
+    temperature,
+    top_p,
+):
+    # Prepare the conversation history
+    messages = [{"role": "system", "content": system_message}]
+    for user_input, bot_response in history:
+        if user_input:
+            messages.append({"role": "user", "content": user_input})
+        if bot_response:
+            messages.append({"role": "assistant", "content": bot_response})
+    messages.append({"role": "user", "content": message})
+
+    # Format the input for the model as plain "role: content" lines
+    conversation_text = "\n".join(
+        f"{msg['role']}: {msg['content']}" for msg in messages
+    )
+    inputs = tokenizer(conversation_text, return_tensors="pt", truncation=True).to(model.device)
+
+    # Generate the model's response; do_sample=True so temperature/top_p take effect
+    outputs = model.generate(
+        inputs.input_ids,
+        attention_mask=inputs.attention_mask,
+        max_new_tokens=max_tokens,
+        do_sample=True,
+        temperature=temperature,
+        top_p=top_p,
+        pad_token_id=tokenizer.eos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Extract only the newly generated assistant turn
+    new_response = response.split("assistant:")[-1].strip()
+    yield new_response
+
+# Gradio app configuration; chatbot must be a gr.Chatbot component, not a string
+demo = gr.ChatInterface(
+    fn=respond,
+    chatbot=gr.Chatbot(label="Assistant"),
+    additional_inputs=[
+        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.95,
+            step=0.05,
+            label="Top-p (nucleus sampling)",
+        ),
+    ],
+)
+
 if __name__ == "__main__":
     demo.launch()
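
Since respond is an ordinary Python generator, the corrected function can be smoke-tested without launching the Gradio UI. A minimal sketch, assuming the file above is saved as app.py and the sreyanghosh/lora_model checkpoint downloads successfully; the file and argument values here are hypothetical, and note that importing app also runs load_model(), so the full model is loaded (demo.launch() itself stays guarded by the __main__ check):

# smoke_test.py -- hypothetical helper, not part of the commit
from app import respond

# respond yields its reply, so iterate over the generator
for reply in respond(
    message="Hello, who are you?",
    history=[],  # no previous (user, assistant) turns
    system_message="You are a friendly Chatbot.",
    max_tokens=64,
    temperature=0.7,
    top_p=0.95,
):
    print(reply)

Because the prompt is flattened into plain "role: content" lines rather than a chat template, the reply is recovered by splitting the decoded text on the last "assistant:" marker, which is exactly what the corrected respond function does.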