Loewolf commited on
Commit
ad42d55
·
1 Parent(s): 7211673

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -25
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  from transformers import pipeline, set_seed
3
 
4
  # Setzen eines Seeds für Reproduzierbarkeit
@@ -8,46 +9,62 @@ set_seed(42)
8
  model = pipeline("text-generation", model="Loewolf/GPT_1")
9
  tokenizer = model.tokenizer
10
 
11
- def generate_text(input_text, temperature, top_k, top_p, length, system_prompt):
12
- # Anpassen des Eingabetextes mit System-Prompt, falls vorhanden
13
- adjusted_input_text = system_prompt + input_text if system_prompt else input_text
14
-
15
  # Konvertieren des Eingabetextes in Token-IDs
16
- input_ids = tokenizer.encode(adjusted_input_text, return_tensors="pt")
 
 
 
17
 
18
  # Einstellung der maximalen Länge
19
- max_length = length if length else model.model.config.n_positions
20
 
21
  # Textgenerierung mit spezifischen Parametern
22
- output = model.model.generate(
23
  input_ids,
 
24
  max_length=max_length,
25
- temperature=temperature,
26
- top_k=top_k,
27
- top_p=top_p,
28
  no_repeat_ngram_size=2,
 
 
 
 
 
 
 
29
  pad_token_id=tokenizer.eos_token_id
30
  )
31
 
32
  # Konvertieren der generierten Token-IDs zurück in Text
33
- return tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
34
 
35
  # Erstellen der Gradio-Schnittstelle
36
- interface = gr.Interface(
37
- fn=generate_text,
38
- inputs=[
39
- gr.inputs.Textbox(lines=2, placeholder="Geben Sie Ihren Text hier ein..."),
40
- gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.9, label="Temperature"),
41
- gr.inputs.Slider(minimum=1, maximum=100, step=1, default=50, label="Top K"),
42
- gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=0.9, label="Top P"),
43
- gr.inputs.Number(default=50, label="Länge"),
44
- gr.inputs.Textbox(lines=2, placeholder="System-Prompt (optional)")
45
- ],
46
- outputs="text",
47
- layout="vertical"
48
- )
 
 
 
49
 
50
  # Starten der Gradio-App
51
- interface.launch()
52
 
53
 
 
1
  import gradio as gr
2
+ import torch
3
  from transformers import pipeline, set_seed
4
 
5
  # Setzen eines Seeds für Reproduzierbarkeit
 
9
  model = pipeline("text-generation", model="Loewolf/GPT_1")
10
  tokenizer = model.tokenizer
11
 
12
+ def generate_text(input_text, temp, top_k, top_p, length):
 
 
 
13
  # Konvertieren des Eingabetextes in Token-IDs
14
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
15
+
16
+ # Erstellung der Attention-Mask
17
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.bool)
18
 
19
  # Einstellung der maximalen Länge
20
+ max_length = model.model.config.n_positions if len(input_ids[0]) > model.model.config.n_positions else len(input_ids[0]) + length
21
 
22
  # Textgenerierung mit spezifischen Parametern
23
+ beam_output = model.model.generate(
24
  input_ids,
25
+ attention_mask=attention_mask,
26
  max_length=max_length,
27
+ min_length=4,
28
+ num_beams=5,
 
29
  no_repeat_ngram_size=2,
30
+ early_stopping=True,
31
+ temperature=temp,
32
+ top_p=top_p,
33
+ top_k=top_k,
34
+ length_penalty=2.0,
35
+ do_sample=True,
36
+ eos_token_id=tokenizer.eos_token_id,
37
  pad_token_id=tokenizer.eos_token_id
38
  )
39
 
40
  # Konvertieren der generierten Token-IDs zurück in Text
41
+ return tokenizer.decode(beam_output[0], skip_special_tokens=True)
42
+
43
+ def chat_with_model(user_input, history, temperature, top_k, top_p, length, system_prompt):
44
+ combined_input = f"{history}\nNutzer: {user_input}\n{system_prompt}:"
45
+ response = generate_text(combined_input, temperature, top_k, top_p, length)
46
+ new_history = f"{combined_input}\n{response}"
47
+ return "", new_history # Leerer String für user_input, um das Eingabefeld zurückzusetzen
48
 
49
  # Erstellen der Gradio-Schnittstelle
50
+ with gr.Blocks() as demo:
51
+ with gr.Row():
52
+ history = gr.Textbox(label="Chatverlauf", lines=10, interactive=False)
53
+ user_input = gr.Textbox(label="Deine Nachricht")
54
+ system_prompt = gr.Textbox(label="System Prompt", value="Löwolf GPT")
55
+ with gr.Column(scale=1):
56
+ temperature = gr.Slider(minimum=0, maximum=1, step=0.01, label="Temperature", value=0.9)
57
+ top_k = gr.Slider(minimum=0, maximum=100, step=1, label="Top K", value=50)
58
+ top_p = gr.Slider(minimum=0, maximum=1, step=0.01, label="Top P", value=0.9)
59
+ length = gr.Slider(minimum=1, maximum=100, step=1, label="Länge", value=20)
60
+ submit_btn = gr.Button("Senden")
61
+ submit_btn.click(
62
+ chat_with_model,
63
+ inputs=[user_input, history, temperature, top_k, top_p, length, system_prompt],
64
+ outputs=[user_input, history]
65
+ )
66
 
67
  # Starten der Gradio-App
68
+ demo.launch()
69
 
70