Samuel4677 committed on
Commit
ec80b9e
·
verified ·
1 Parent(s): 4c03c86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -21
app.py CHANGED
@@ -1,29 +1,35 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import gradio as gr
 
3
 
4
- # Załaduj model
5
- model_name = "google/mt5-small"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
 
9
- # Funkcja odpowiadająca na pytanie
10
- def answer_question(question):
11
- input_text = f"Pytanie: {question} Odpowiedź:"
12
- inputs = tokenizer.encode(input_text, return_tensors="pt")
13
- output = model.generate(
 
 
 
 
14
  inputs,
15
- max_new_tokens=60,
16
- do_sample=False,
17
- temperature=0.3,
18
- top_p=0.95
 
 
19
  )
20
- return tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
21
 
22
  # Gradio UI
23
- gr.Interface(
24
- fn=answer_question,
25
- inputs=gr.Textbox(lines=2, placeholder="Zadaj pytanie..."),
26
- outputs=gr.Textbox(),
27
- title="🤖 Polski Chatbot AI",
28
- description="Zadaj pytanie po polsku, a chatbot udzieli sensownej odpowiedzi"
29
- ).launch()
 
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch

# Load the Polish GPT-2 checkpoint and its matching tokenizer from the
# Hugging Face hub (downloaded on first run, cached afterwards).
model_name = "radlab/polish-gpt2-small-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Chat callback used by the Gradio ChatInterface below.
def chatbot(prompt, history=None):
    """Generate the model's reply to *prompt* given the conversation so far.

    Parameters
    ----------
    prompt : str
        The user's latest message.
    history : list[tuple[str, str]] | None
        Prior (user, bot) exchanges. NOTE(review): assumes the tuple-style
        history format — confirm against the installed gradio version
        (newer gradio versions may pass message dicts instead).

    Returns
    -------
    str
        Only the newly generated answer text. gr.ChatInterface manages the
        history itself and expects the callback to return the bot message;
        the previous ``return answer, history`` tuple (combined with
        mutating the caller's history list) would be rendered verbatim in
        the chat UI and duplicate turns.
    """
    # Fix the mutable-default-argument pitfall: a shared `[]` default would
    # silently accumulate conversation state across independent calls.
    history = [] if history is None else history

    # Flatten the conversation into a single prompt string for the LM.
    history_text = ""
    for user, bot in history:
        history_text += f"Użytkownik: {user}\nAI: {bot}\n"
    history_text += f"Użytkownik: {prompt}\nAI:"

    # Truncate to GPT-2's 1024-token context window.
    inputs = tokenizer.encode(
        history_text, return_tensors="pt", truncation=True, max_length=1024
    )
    outputs = model.generate(
        inputs,
        max_length=inputs.shape[1] + 80,  # allow up to 80 new tokens
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the new continuation: drop the echoed prompt prefix and cut
    # where the model starts inventing the next user turn. NOTE(review):
    # slicing by len(history_text) assumes decode round-trips the prompt
    # text exactly — usually true for GPT-2 tokenizers, but verify.
    answer = decoded[len(history_text):].split("Użytkownik:")[0].strip()
    return answer
# Gradio UI: a chat front-end wired to the chatbot() callback.
demo = gr.ChatInterface(
    fn=chatbot,
    title="🤖 Polski Chatbot AI",
    description="Model: radlab/polish-gpt2-small-v2",
)
demo.launch()