davideuler committed
Commit 9482433 · 1 Parent(s): eaddcef

initial version

Files changed (5)
  1. .gitignore +11 -0
  2. .python-version +1 -0
  3. main.py +96 -0
  4. pyproject.toml +13 -0
  5. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,11 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+ .gradio
+
+ # Virtual environments
+ .venv
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
main.py ADDED
@@ -0,0 +1,96 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, T5ForConditionalGeneration, T5Tokenizer
+
+ class MultiModelChat:
+     def __init__(self):
+         self.models = {}
+
+     def ensure_model_loaded(self, model_name):
+         """Lazy load a model only when needed"""
+         if model_name not in self.models:
+             print(f"Loading {model_name} model...")
+
+             if model_name == 'SmolLM2':
+                 self.models['SmolLM2'] = {
+                     'tokenizer': AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct"),
+                     'model': AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
+                 }
+             elif model_name == 'FLAN-T5':
+                 self.models['FLAN-T5'] = {
+                     'tokenizer': T5Tokenizer.from_pretrained("google/flan-t5-small"),
+                     'model': T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
+                 }
+
+             # Set pad token for the newly loaded model
+             if self.models[model_name]['tokenizer'].pad_token is None:
+                 self.models[model_name]['tokenizer'].pad_token = self.models[model_name]['tokenizer'].eos_token
+
+             print(f"{model_name} model loaded successfully!")
+
+     def chat(self, message, history, model_choice):
+         if model_choice == "SmolLM2":
+             return self.chat_smol(message, history)
+         elif model_choice == "FLAN-T5":
+             return self.chat_flan(message, history)
+
+     def chat_smol(self, message, history):
+         self.ensure_model_loaded('SmolLM2')
+
+         tokenizer = self.models['SmolLM2']['tokenizer']
+         model = self.models['SmolLM2']['model']
+
+         inputs = tokenizer(f"User: {message}\nAssistant:", return_tensors="pt")
+         outputs = model.generate(
+             inputs.input_ids,
+             max_new_tokens=80,
+             temperature=0.7,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id
+         )
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return response.split("Assistant:")[-1].strip()
+
+     def chat_flan(self, message, history):
+         self.ensure_model_loaded('FLAN-T5')
+
+         tokenizer = self.models['FLAN-T5']['tokenizer']
+         model = self.models['FLAN-T5']['model']
+
+         inputs = tokenizer(f"Answer the question: {message}", return_tensors="pt")
+         outputs = model.generate(inputs.input_ids, max_length=100)
+         return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ chat_app = MultiModelChat()
+
+ def respond(message, history, model_choice):
+     return chat_app.chat(message, history, model_choice)
+
+ with gr.Blocks(theme="soft") as demo:
+     gr.Markdown("# Multi-Model Tiny Chatbot")
+
+     with gr.Row():
+         model_dropdown = gr.Dropdown(
+             choices=["SmolLM2", "FLAN-T5"],
+             value="SmolLM2",
+             label="Select Model"
+         )
+
+     chatbot = gr.Chatbot(height=400)
+     msg = gr.Textbox(label="Message", placeholder="Type your message here...")
+     clear = gr.Button("Clear")
+
+     def user_message(message, history):
+         return "", history + [[message, None]]
+
+     def bot_message(history, model_choice):
+         user_msg = history[-1][0]
+         bot_response = chat_app.chat(user_msg, history[:-1], model_choice)
+         history[-1][1] = bot_response
+         return history
+
+     msg.submit(user_message, [msg, chatbot], [msg, chatbot]).then(
+         bot_message, [chatbot, model_dropdown], chatbot
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ demo.launch()
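
Note that chat_smol prompts only the latest message, so earlier turns in history never reach the model, and the module-level respond helper is defined but not wired into the Blocks UI (the submit chain goes through bot_message instead). One possible refinement, sketched below under the assumption that tokenizer, model, message, and history are bound exactly as inside chat_smol: SmolLM2-135M-Instruct ships a chat template, so the conversation can be formatted with tokenizer.apply_chat_template rather than the hand-rolled "User: ... Assistant:" framing.

    # Sketch only: build a role-tagged conversation from the tuples-style history
    messages = []
    for user_turn, bot_turn in history:
        messages.append({"role": "user", "content": user_turn})
        if bot_turn is not None:
            messages.append({"role": "assistant", "content": bot_turn})
    messages.append({"role": "user", "content": message})

    # apply_chat_template emits the prompt format the model was instruction-tuned on
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    outputs = model.generate(
        input_ids,
        max_new_tokens=80,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, skipping the echoed prompt
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)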
pyproject.toml ADDED
@@ -0,0 +1,13 @@
+ [project]
+ name = "small-model-chat"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "gradio>=5.31.0",
+     "huggingface-hub[hf-xet]>=0.31.4",
+     "sentencepiece>=0.2.0",
+     "torch>=2.7.0",
+     "transformers>=4.52.3",
+ ]
uv.lock ADDED
The diff for this file is too large to render. See raw diff
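
With the interpreter pinned to 3.12 in .python-version and the resolved versions checked in as uv.lock, running uv run main.py from a checkout should install the dependencies above and launch the Gradio demo; each model is then downloaded from the Hugging Face Hub the first time it is selected and messaged. The sentencepiece dependency is what T5Tokenizer needs for the FLAN-T5 vocabulary.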