Loewolf commited on
Commit
958d235
·
1 Parent(s): dd10e61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -3
app.py CHANGED
@@ -1,12 +1,44 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
 
 
3
 
4
  # Laden des GPT-Modells mit Hugging Face Pipeline
5
  model = pipeline("text-generation", model="Loewolf/GPT_1")
 
6
 
7
- # Definition einer Wrapper-Funktion für das Modell
8
  def generate_text(input_text):
9
- return model(input_text)[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Erstellen der Gradio-Schnittstelle
12
  interface = gr.Interface(fn=generate_text, inputs="text", outputs="text")
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import pipeline, set_seed
4
+
5
+ # Setzen eines Seeds für Reproduzierbarkeit
6
+ set_seed(42)
7
 
8
  # Laden des GPT-Modells mit Hugging Face Pipeline
9
  model = pipeline("text-generation", model="Loewolf/GPT_1")
10
+ tokenizer = model.tokenizer
11
 
 
12
  def generate_text(input_text):
13
+ # Konvertieren des Eingabetextes in Token-IDs
14
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
15
+
16
+ # Erstellung der Attention-Mask
17
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.bool)
18
+
19
+ # Einstellung der maximalen Länge
20
+ max_length = model.model.config.n_positions if len(input_ids[0]) > model.model.config.n_positions else len(input_ids[0]) + 20
21
+
22
+ # Textgenerierung mit spezifischen Parametern
23
+ beam_output = model.model.generate(
24
+ input_ids,
25
+ attention_mask=attention_mask,
26
+ max_length=max_length,
27
+ min_length=4,
28
+ num_beams=5,
29
+ no_repeat_ngram_size=2,
30
+ early_stopping=True,
31
+ temperature=0.9,
32
+ top_p=0.90,
33
+ top_k=50,
34
+ length_penalty=2.0,
35
+ do_sample=True,
36
+ eos_token_id=tokenizer.eos_token_id,
37
+ pad_token_id=tokenizer.eos_token_id
38
+ )
39
+
40
+ # Konvertieren der generierten Token-IDs zurück in Text
41
+ return tokenizer.decode(beam_output[0], skip_special_tokens=True)
42
 
43
  # Erstellen der Gradio-Schnittstelle
44
  interface = gr.Interface(fn=generate_text, inputs="text", outputs="text")