# File: app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Model configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    "soynade-research/Oolel-v0.1",
    torch_dtype=torch.bfloat16,
    device_map="auto" if torch.cuda.is_available() else None
)
tokenizer = AutoTokenizer.from_pretrained("soynade-research/Oolel-v0.1")

def generate_response(messages, max_new_tokens=1024, temperature=0.1):
    # Build the prompt from the chat history using the model's chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    generated_ids = model.generate(
        **model_inputs,  # pass input_ids and attention_mask together
        max_new_tokens=max_new_tokens,
        do_sample=True,  # temperature has no effect under greedy decoding
        temperature=temperature
    )

    # Keep only the newly generated tokens, stripping the prompt
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
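
# Quick smoke test, kept commented out so it doesn't run at import time
# (a minimal sketch, assuming Oolel follows the standard role/content chat format):
# print(generate_response([{"role": "user", "content": "Nanga def?"}], max_new_tokens=64))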

# Gradio interface configuration
def chat_interface(message, history):
    # With type="messages", Gradio passes the history as a list of
    # {"role": ..., "content": ...} dicts, which is already the format the
    # chat template expects; normalize it to drop any extra keys
    formatted_history = [
        {"role": m["role"], "content": m["content"]} for m in history
    ]

    # Append the new user message
    formatted_history.append({"role": "user", "content": message})

    # Generate the response
    return generate_response(formatted_history)

# Create the Gradio interface
iface = gr.ChatInterface(
    fn=chat_interface,
    title="Chat with Oolel",
    description="Chat with the Oolel model",
    type="messages"
)

if __name__ == "__main__":
    iface.launch()
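
# Typical local usage (assuming Gradio's default settings):
#   python app.py   # the UI is served at http://127.0.0.1:7860 by default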