Pgohari commited on
Commit
7598a52
·
verified ·
1 Parent(s): f0f5e80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -8
app.py CHANGED
@@ -1,16 +1,33 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  import gradio as gr
3
 
4
- model_name = "albert-base-v2"
 
5
  tokenizer = AutoTokenizer.from_pretrained(model_name)
6
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
7
 
 
8
  def chatbot(user_input):
9
- inputs = tokenizer(user_input, return_tensors="pt")
10
- outputs = model(**inputs)
11
- response = outputs.logits.argmax(-1).item()
12
- return f"Predicted response: {response}"
 
 
 
 
 
 
 
13
 
14
- demo = gr.Interface(fn=chatbot, inputs="text", outputs="text", title="ALBERT Chatbot")
 
 
 
 
 
 
 
 
15
  demo.launch()
16
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import gradio as gr
3
 
4
+ # Charger le modèle mT5-Small multilingue
5
+ model_name = "google/mt5-small"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
 
9
+ # Définir la fonction du chatbot
10
  def chatbot(user_input):
11
+ # Tokeniser l'entrée utilisateur
12
+ inputs = tokenizer(user_input, return_tensors="pt", max_length=128, truncation=True)
13
+
14
+ # Générer une réponse avec le modèle
15
+ outputs = model.generate(inputs["input_ids"], max_length=50, num_beams=4, early_stopping=True)
16
+
17
+ # Décoder la réponse générée
18
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
19
+
20
+ # Retourner la réponse finale
21
+ return response
22
 
23
+ # Créer une interface Gradio pour tester le chatbot
24
+ demo = gr.Interface(
25
+ fn=chatbot,
26
+ inputs="text",
27
+ outputs="text",
28
+ title="Chatbot en français avec mT5-Small"
29
+ )
30
+
31
+ # Lancer l'application Gradio
32
  demo.launch()
33