import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configuration parameters
MAX_MAX_NEW_TOKENS = 100
DEFAULT_MAX_NEW_TOKENS = 20
MAX_INPUT_TOKEN_LENGTH = 400  # cap the prompt at 400 tokens

# Load model and tokenizer
model_id = "Loewolf/GPT_1"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
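
# GPT-2-style tokenizers typically ship without a pad token; defaulting it to
# EOS avoids padding warnings. (Precautionary: generate() below also passes
# pad_token_id explicitly.)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id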

# Chat generation function for the Gradio interface
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> str:
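    """Generate one assistant reply for `message` given the chat history."""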
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # tokenizer() cannot consume a list of message dicts directly; render the
    # conversation with the tokenizer's chat template first (this assumes the
    # tokenizer defines one; otherwise flatten the messages into a plain string).
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH,
        return_tensors="pt",
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # temperature/top_p/top_k only take effect when sampling
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )

    outputs = model.generate(**generate_kwargs)
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)

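# A minimal sketch of optional token streaming, assuming transformers'
# TextIteratorStreamer; generate() would then `yield` partial strings instead
# of returning a single response:
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate, kwargs={**generate_kwargs, "streamer": streamer}).start()
#   for new_text in streamer:
#       yield new_text
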
# Gradio Interface
chat_interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Message"),
        gr.JSON(label="Chat History", value=[]),  # default to an empty history; None would break iteration
        gr.Textbox(label="System Prompt", lines=2),
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
    ],
    outputs="text",
    live=True
)
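
# A quick smoke test without the UI (hypothetical prompt values):
#   print(generate("Hello!", [], "You are a helpful assistant."))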

# Start the Gradio server
if __name__ == "__main__":
    chat_interface.launch()