rasyosef commited on
Commit
d15c04c
·
verified ·
1 Parent(s): cb638de

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline
from threading import Thread

# Hugging Face Hub repository of the fine-tuned Amharic instruct model.
model_id = "rasyosef/Llama-3.2-400M-Amharic-Instruct-Poems-Stories-Wikipedia"

# Single placement decision, reused for both the model load and the pipeline.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map=device,
)

# Text-generation pipeline wrapping the preloaded model and tokenizer.
llama3_am = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    eos_token_id=tokenizer.eos_token_id,
    device_map=device,
)
23
+
24
# Function that accepts a prompt and generates text.
def generate(message, chat_history, max_new_tokens=64):
    """Stream a chat completion for *message* given prior (user, assistant)
    turn pairs in *chat_history*, yielding the partial response as it grows."""
    # Rebuild the conversation in the chat-template message format.
    history = []
    for sent, received in chat_history:
        history.append({"role": "user", "content": sent})
        history.append({"role": "assistant", "content": received})
    history.append({"role": "user", "content": message})

    # Guard: refuse conversations whose templated token count exceeds 512.
    if len(tokenizer.apply_chat_template(history)) > 512:
        yield "chat history is too long"
        return

    # Run generation on a worker thread; the streamer feeds tokens back here.
    streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=300.0)
    worker = Thread(
        target=llama3_am,
        kwargs={
            "text_inputs": history,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": 1.1,
            "streamer": streamer,
        },
    )
    worker.start()

    # Accumulate streamed fragments and re-yield the running response so the
    # UI updates token by token.
    generated_text = ""
    for word in streamer:
        generated_text += word
        response = generated_text.strip()
        yield response
57
+
58
# Chat interface with gradio.
with gr.Blocks() as demo:
    gr.Markdown("""
# Llama 3.2 400M Amharic Chatbot Demo
""")

    # Extra input wired into `generate` as its `max_new_tokens` argument.
    tokens_slider = gr.Slider(8, 256, value=64, label="Maximum new tokens", info="A larger `max_new_tokens` parameter value gives you longer text responses but at the cost of a slower response time.")

    # NOTE(review): the example strings below appear encoding-garbled in this
    # dump; verify they match the original UTF-8 Amharic text.
    chatbot = gr.ChatInterface(
        chatbot=gr.Chatbot(height=400),
        fn=generate,
        additional_inputs=[tokens_slider],
        stop_btn=None,
        examples=[
            ["แˆฐแˆ‹แˆ"],
            ["แˆฐแˆ‹แˆแฃ แŠฅแŠ•แ‹ดแ‰ต แŠแˆ…?"],
            ["แŠ แŠ•แ‰ฐ แˆ›แŠแˆ…?"],
            ["แŒแŒฅแˆ แƒแแˆแŠ"],
            ["แˆตแˆˆ แ‹ญแ‰…แˆญแ‰ณ แŒแŒฅแˆ แŒปแแˆแŠ"],
            ["แŠ แŠ•แ‹ต แ‰ฐแˆจแ‰ต แŠ แŒซแ‹แ‰ฐแŠ"],
            ["แˆตแˆˆ แŒ…แ‰ฅแŠ“ แŠ แŠ•แ‰ แˆณ แ‰ฐแˆจแ‰ต แŠ•แŒˆแˆจแŠ"],
            ["แ‰€แˆแ‹ต แŠ•แŒˆแˆจแŠ"],
            ["แˆตแˆˆ แˆตแˆซ แŠ แŒฅแŠแ‰ต แŠ แŠ•แ‹ต แ‰€แˆแ‹ต แŠ•แŒˆแˆจแŠ"],
            ["แ‹ณแŒแˆ›แ‹Š แ‰ดแ‹Žแ‹ตแˆฎแˆต แˆ›แŠ• แŠแ‹?"],
            ["แ‹ณแŒแˆ›แ‹Š แˆแŠ•แˆŠแŠญ แˆ›แŠ• แŠแ‹?"],
            ["แˆตแˆˆ แŠ แ‹ฒแˆต แŠ แ‰ แ‰ฃ แ‹ฉแŠ’แ‰จแˆญแˆตแ‰ฒ แŒฅแ‰‚แ‰ต แŠฅแ‹แŠแ‰ณแ‹Žแ‰ฝแŠ• แŠ แŒซแ‹แ‰ฐแŠ"],
            ["แˆตแˆˆ แŒƒแ“แŠ• แŒฅแ‰‚แ‰ต แŠฅแ‹แŠแ‰ณแ‹Žแ‰ฝแŠ• แŠ•แŒˆแˆจแŠ"],
            ["แˆตแˆˆ แˆ›แ‹ญแŠญแˆฎแˆถแแ‰ต แŒฅแ‰‚แ‰ต แŠฅแ‹แŠแ‰ณแ‹Žแ‰ฝแŠ• แŠ•แŒˆแˆจแŠ"],
            ["แŒ‰แŒแˆ แˆแŠ•แ‹ตแŠ• แŠแ‹?"],
            ["แ‰ขแ‰ตแŠฎแ‹ญแŠ• แˆแŠ•แ‹ตแŠ• แŠแ‹?"],
        ]
    )

# Queue enables streaming generator responses; debug surfaces errors in logs.
demo.queue().launch(debug=True)