Spaces:
Runtime error
Runtime error
File size: 2,459 Bytes
5459aa9 f863056 5459aa9 57579dd 5459aa9 57579dd 5459aa9 06ea162 f372d1e 57579dd 06ea162 5459aa9 57579dd 5459aa9 57579dd 5459aa9 06ea162 5459aa9 57579dd 5459aa9 06ea162 57579dd 5459aa9 57579dd 2909fb3 57579dd db924a3 57579dd 5459aa9 57579dd 5459aa9 57579dd 54e1be2 a727207 57579dd 5459aa9 57579dd a727207 4429d9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
# Configuration parameters
MAX_MAX_NEW_TOKENS = 100        # upper bound exposed on the "Max new tokens" slider
DEFAULT_MAX_NEW_TOKENS = 20     # default number of tokens generated per reply
MAX_INPUT_TOKEN_LENGTH = 400 # prompt is truncated to at most 400 tokens
# Load model and tokenizer (downloads from the Hugging Face Hub on first run)
model_id = "Loewolf/GPT_1"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Gradio chat interface function
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> str:
    """Generate one assistant reply for *message* given the prior chat turns.

    Args:
        message: The latest user message.
        chat_history: Previous (user, assistant) message pairs.
        system_prompt: Optional system instruction prepended to the conversation.
        max_new_tokens: Maximum number of tokens to generate for the reply.
        temperature / top_p / top_k / repetition_penalty: Sampling controls
            forwarded to ``model.generate``.

    Returns:
        The newly generated assistant text (prompt is not echoed back).
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    # BUG FIX: the original passed the list of role dicts directly to
    # ``tokenizer(...)``, which expects text and raises at runtime. Render the
    # conversation to a prompt string first; fall back to a simple transcript
    # format for tokenizers that ship without a chat template.
    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
    else:
        prompt = "\n".join(
            f"{turn['role']}: {turn['content']}" for turn in conversation
        ) + "\nassistant:"

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH,
    )
    input_ids = inputs["input_ids"]

    outputs = model.generate(
        input_ids=input_ids,
        # Pass the attention mask so padding/truncation is handled correctly.
        attention_mask=inputs["attention_mask"],
        max_length=input_ids.shape[1] + max_new_tokens,
        # BUG FIX: without do_sample=True, temperature/top_p/top_k are
        # silently ignored and decoding is greedy.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.eos_token_id,
    )
    # BUG FIX: decode only the newly generated tokens — the original decoded
    # the whole sequence, echoing the entire prompt back to the user.
    new_tokens = outputs[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Gradio Interface
# Builds the UI: one textbox for the message, a JSON field for prior history
# (must be a list of [user, assistant] pairs — generate() unpacks each entry
# as a 2-tuple), a system-prompt box, and sliders mirroring the sampling
# parameters of generate().
chat_interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Message"),
        gr.JSON(label="Chat History"),
        gr.Textbox(label="System Prompt", lines=2),
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
    ],
    outputs="text",
    # NOTE(review): live=True re-runs generate() on every input change (e.g.
    # per keystroke), which is expensive for model inference — consider
    # live=False with an explicit submit button; confirm intended UX.
    live=True
)
# Launch the Gradio server when run as a script
if __name__ == "__main__":
    chat_interface.launch()
|