HaveAI commited on
Commit
ef7ee72
·
verified ·
1 Parent(s): 7ec8e1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -2,19 +2,19 @@ import torch
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
3
  import gradio as gr
4
 
5
- # Загрузка модели и токенизатора
6
- model_id = "microsoft/phi-2"
7
  tokenizer = AutoTokenizer.from_pretrained(model_id)
8
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
9
  model.to("cpu")
10
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
11
 
12
- # Функция генерации ответа
13
  def chat_fn(prompt):
14
- inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
15
- outputs = model.generate(**inputs, max_new_tokens=100, streamer=streamer)
16
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
17
- return response[len(prompt):].strip()
 
 
 
 
18
 
19
- # Интерфейс Gradio
20
- gr.Interface(fn=chat_fn, inputs="text", outputs="text", title="💬 Flare GPT — на Phi-2").launch()
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
3
  import gradio as gr
4
 
5
+ model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
6
  tokenizer = AutoTokenizer.from_pretrained(model_id)
7
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
8
  model.to("cpu")
9
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
10
 
 
11
  def chat_fn(prompt):
12
+ messages = [{"role": "user", "content": prompt}]
13
+ input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
14
+ inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
15
+ output = model.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
16
+ response = tokenizer.decode(output[0], skip_special_tokens=True)
17
+ return response.split("user")[-1].strip()
18
+
19
+ gr.Interface(fn=chat_fn, inputs="text", outputs="text", title="💬 FlareGPT на TinyLlama").launch()
20