hackergeek98 committed
Commit f960061 · verified · 1 Parent(s): 31ede35

Update app.py

Files changed (1)
  1. app.py +32 -13
app.py CHANGED
@@ -1,3 +1,4 @@
+ import gradio as gr
  import torch
  from peft import PeftModel
  from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -5,19 +6,37 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
  # Load tokenizer
  tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")

- # Load base model on CPU
- base_model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt")
+ # Load base model on CPU with optimizations
+ base_model = AutoModelForCausalLM.from_pretrained(
+     "google/gemma-3-1b-pt",
+     torch_dtype=torch.bfloat16,  # bfloat16 roughly halves memory vs. float32
+     low_cpu_mem_usage=True,
+ )

- # Load fine-tuned PEFT model
+ # Load the fine-tuned PEFT adapter on top of the base model
  model = PeftModel.from_pretrained(base_model, "hackergeek98/gemma-finetuned")
+ model = model.to("cpu")  # Ensure it runs on CPU

- # Ensure model runs on CPU
- model = model.to("cpu")
-
- # Test inference
- input_text = "Hello, how are you?"
- input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cpu")
-
- # Generate output
- output = model.generate(input_ids, max_length=50)
- print(tokenizer.decode(output[0], skip_special_tokens=True))
+ # Chat handler: gr.ChatInterface calls fn(message, history, *additional_inputs)
+ def chat(message, history, system_message):
+     input_ids = tokenizer(message, return_tensors="pt").input_ids.to("cpu")
+
+     with torch.no_grad():  # Disable gradient tracking for inference
+         output_ids = model.generate(input_ids, max_new_tokens=100)
+
+     # ChatInterface tracks the conversation history itself, so return only the reply
+     return tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+ # Gradio UI
+ demo = gr.ChatInterface(
+     chat,
+     chatbot=gr.Chatbot(height=400),
+     additional_inputs=[
+         gr.Textbox(value="Welcome to the chatbot!", label="System message")
+     ],
+     title="Fine-Tuned Gemma Chatbot",
+     description="This chatbot is fine-tuned on Persian text using Gemma.",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
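
Because gr.ChatInterface invokes chat() as fn(message, history, *additional_inputs), the handler can be sanity-checked directly in a Python session before launching the UI. The lines below are a minimal sketch, assuming the loading code above has already run; the Persian prompt and the empty history are arbitrary example values, and the system message mirrors the Textbox default from the interface.

# Quick smoke test for the chat handler (assumes the model is already loaded).
# The prompt is an arbitrary example: "Hello, how are you?" in Persian.
reply = chat("سلام، حال شما چطور است؟", history=[], system_message="Welcome to the chatbot!")
print(reply)

Note that google/gemma-3-1b-pt is the base pretrained checkpoint rather than an instruction-tuned one, so unless the fine-tuning in hackergeek98/gemma-finetuned taught a conversational format, the reply will read as a free-form continuation of the prompt.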