ysdede commited on
Commit
012c0fa
·
1 Parent(s): b327205

Refactor app.py to integrate Turkish LLaMA 8B model

Browse files

- Updated model loading to use bfloat16 precision.
- Added terminators for proper message handling.
- Included `add_generation_prompt=True` in the tokenizer's chat template application.
- Enhanced the system prompt with a default instruction for the AI assistant.
- Updated examples to reflect relevant Turkish language queries.

Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -40,10 +40,15 @@ if torch.cuda.is_available():
40
  model = AutoModelForCausalLM.from_pretrained(
41
  model_id,
42
  device_map="auto",
43
- torch_dtype=torch.float16,
44
  )
45
  tokenizer = AutoTokenizer.from_pretrained(model_id)
46
  tokenizer.use_default_system_prompt = False
 
 
 
 
 
47
 
48
 
49
  @spaces.GPU
@@ -63,7 +68,11 @@ def generate(
63
  conversation += chat_history
64
  conversation.append({"role": "user", "content": message})
65
 
66
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
 
 
 
 
67
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
68
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
69
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
@@ -80,6 +89,7 @@ def generate(
80
  temperature=temperature,
81
  num_beams=1,
82
  repetition_penalty=repetition_penalty,
 
83
  )
84
  t = Thread(target=model.generate, kwargs=generate_kwargs)
85
  t.start()
@@ -93,7 +103,11 @@ def generate(
93
  chat_interface = gr.ChatInterface(
94
  fn=generate,
95
  additional_inputs=[
96
- gr.Textbox(label="System prompt", lines=6),
 
 
 
 
97
  gr.Slider(
98
  label="Max new tokens",
99
  minimum=1,
@@ -133,10 +147,7 @@ chat_interface = gr.ChatInterface(
133
  stop_btn=None,
134
  examples=[
135
  ["Merhaba! Nasılsın?"],
136
- ["Python programlama dilini kısaca açıklayabilir misin?"],
137
- ["Külkedisi masalının özetini bir cümlede anlat."],
138
  ["Yapay zeka alanında açık kaynak kodun faydaları nelerdir?"],
139
- ["İstanbul'un en ünlü turistik yerlerini sıralar mısın?"],
140
  ],
141
  cache_examples=False,
142
  type="messages",
 
40
  model = AutoModelForCausalLM.from_pretrained(
41
  model_id,
42
  device_map="auto",
43
+ torch_dtype=torch.bfloat16,
44
  )
45
  tokenizer = AutoTokenizer.from_pretrained(model_id)
46
  tokenizer.use_default_system_prompt = False
47
+
48
+ TERMINATORS = [
49
+ tokenizer.eos_token_id,
50
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
51
+ ]
52
 
53
 
54
  @spaces.GPU
 
68
  conversation += chat_history
69
  conversation.append({"role": "user", "content": message})
70
 
71
+ input_ids = tokenizer.apply_chat_template(
72
+ conversation,
73
+ add_generation_prompt=True,
74
+ return_tensors="pt"
75
+ )
76
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
77
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
78
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
89
  temperature=temperature,
90
  num_beams=1,
91
  repetition_penalty=repetition_penalty,
92
+ eos_token_id=TERMINATORS,
93
  )
94
  t = Thread(target=model.generate, kwargs=generate_kwargs)
95
  t.start()
 
103
  chat_interface = gr.ChatInterface(
104
  fn=generate,
105
  additional_inputs=[
106
+ gr.Textbox(
107
+ label="System prompt",
108
+ lines=6,
109
+ value="Sen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir.",
110
+ ),
111
  gr.Slider(
112
  label="Max new tokens",
113
  minimum=1,
 
147
  stop_btn=None,
148
  examples=[
149
  ["Merhaba! Nasılsın?"],
 
 
150
  ["Yapay zeka alanında açık kaynak kodun faydaları nelerdir?"],
 
151
  ],
152
  cache_examples=False,
153
  type="messages",