BarBar288 commited on
Commit
cbaa4b8
·
verified ·
1 Parent(s): d4edab2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ # Define a dictionary of model names and their corresponding Hugging Face model IDs
6
+ models = {
7
+ "GPT-Neo-125M": "EleutherAI/gpt-neo-125M",
8
+ "GPT-J-6B": "EleutherAI/gpt-j-6B",
9
+ "GPT-NeoX-20B": "EleutherAI/gpt-neox-20b",
10
+ "GPT-3.5-Turbo": "gpt2", # Placeholder for illustrative purposes
11
+ }
12
+
13
+ # Initialize tokenizers and models
14
+ tokenizers = {}
15
+ models_loaded = {}
16
+
17
+ for model_name, model_id in models.items():
18
+ tokenizers[model_name] = AutoTokenizer.from_pretrained(model_id)
19
+ models_loaded[model_name] = AutoModelForCausalLM.from_pretrained(model_id)
20
+
21
+ def chat(model_name, user_input, history=[]):
22
+ tokenizer = tokenizers[model_name]
23
+ model = models_loaded[model_name]
24
+
25
+ # Encode the input
26
+ input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
27
+
28
+ # Generate a response
29
+ with torch.no_grad():
30
+ output = model.generate(input_ids, max_length=150, pad_token_id=tokenizer.eos_token_id)
31
+
32
+ response = tokenizer.decode(output[0], skip_special_tokens=True)
33
+
34
+ # Clean up the response to remove the user input part
35
+ response = response[len(user_input):].strip()
36
+
37
+ # Append to chat history
38
+ history.append((user_input, response))
39
+
40
+ return history, history
41
+
42
+ # Define the Gradio interface
43
+ with gr.Blocks() as demo:
44
+ gr.Markdown("## Chat with Different Models")
45
+
46
+ model_choice = gr.Dropdown(list(models.keys()), label="Choose a Model")
47
+ chatbot = gr.Chatbot(label="Chat")
48
+ message = gr.Textbox(label="Message")
49
+ submit = gr.Button("Submit")
50
+
51
+ submit.click(chat, inputs=[model_choice, message, chatbot], outputs=[chatbot, chatbot])
52
+ message.submit(chat, inputs=[model_choice, message, chatbot], outputs=[chatbot, chatbot])
53
+
54
+ # Launch the demo
55
+ demo.launch()