Loewolf committed
Commit 57579dd · 1 Parent(s): dceb3ad

Update app.py

Files changed (1)
  app.py +33 -92
app.py CHANGED
@@ -1,43 +1,32 @@
  import os
+ import torch
  from threading import Thread
  from typing import Iterator
 
  import gradio as gr
  import spaces
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
+ # Configuration parameters
  MAX_MAX_NEW_TOKENS = 100
  DEFAULT_MAX_NEW_TOKENS = 20
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "200"))
+ MAX_INPUT_TOKEN_LENGTH = 200  # cap prompts at the 200 most recent tokens
 
- DESCRIPTION = """\
- # Löwolf Chat
-
- """
-
- LICENSE = """
- <p/>
- ---
- """
-
- if not torch.cuda.is_available():
-     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
+ # Load model and tokenizer
+ model_id = "Loewolf/GPT_1"
  if torch.cuda.is_available():
-     model_id = "Loewolf/GPT_1"
      model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
      tokenizer = AutoTokenizer.from_pretrained(model_id)
-     tokenizer.use_default_system_prompt = False
+ else:
+     raise EnvironmentError("CUDA is not available. This script requires a GPU.")
 
+ # Gradio chat interface function
  @spaces.GPU
  def generate(
      message: str,
      chat_history: list[tuple[str, str]],
      system_prompt: str,
-     max_new_tokens: int = 50,
+     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
      temperature: float = 0.6,
      top_p: float = 0.9,
      top_k: int = 50,
@@ -50,90 +39,42 @@ def generate(
      conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
      conversation.append({"role": "user", "content": message})
 
-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     # Apply the chat template and keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+     input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
      input_ids = input_ids.to(model.device)
 
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
+         input_ids=input_ids,
+         max_new_tokens=min(max_new_tokens, MAX_MAX_NEW_TOKENS),
+         do_sample=True,  # required for temperature/top_p/top_k to take effect
+         temperature=temperature,
          top_p=top_p,
          top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
          repetition_penalty=repetition_penalty,
+         pad_token_id=tokenizer.eos_token_id,
      )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)
+     outputs = model.generate(**generate_kwargs)
+     # Decode only the newly generated tokens, not the echoed prompt
+     return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
 
 
- chat_interface = gr.ChatInterface(
+ # Gradio interface
+ chat_interface = gr.Interface(
      fn=generate,
-     additional_inputs=[
-         gr.Textbox(label="System prompt", lines=6),
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=1.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             value=1.2,
-         ),
-     ],
-     stop_btn=None,
-     examples=[
-         ["Hello there! How are you doing?"],
-         ["Can you explain briefly to me what is the Python programming language?"],
-         ["Explain the plot of Cinderella in a sentence."],
-         ["How many hours does it take a man to eat a Helicopter?"],
-         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-     ],
+     inputs=[
+         gr.Textbox(label="Message"),
+         gr.JSON(label="Chat History"),
+         gr.Textbox(label="System Prompt", lines=2),
+         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.6),
+         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+     ],
+     outputs="text",
+     live=True,
  )
 
- with gr.Blocks(css="style.css") as demo:
-     gr.Markdown(DESCRIPTION)
-     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-     chat_interface.render()
-     gr.Markdown(LICENSE)
-
+ # Launch the Gradio server
  if __name__ == "__main__":
-     demo.queue(max_size=20).launch()
+     chat_interface.launch()
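For a quick local check of the reworked generate(), a call along these lines should do. This is a hedged sketch, not part of the commit: the message comes from the old examples list, the empty history and the system prompt are invented placeholders, it assumes the GPU branch above has loaded model and tokenizer, and it assumes the remaining keyword arguments (e.g. repetition_penalty) keep their defaults.

# Hypothetical smoke test; example values are placeholders.
reply = generate(
    message="Hello there! How are you doing?",
    chat_history=[],
    system_prompt="You are a helpful assistant.",
    max_new_tokens=20,
    temperature=0.6,
    top_p=0.9,
    top_k=50,
)
print(reply)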
 
 
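Note that dropping TextIteratorStreamer means the demo now blocks until generation finishes instead of streaming tokens as they arrive. If streaming is wanted back, the deleted pattern can be reinstated roughly as follows; this is a sketch reconstructed from the removed lines above, assuming the module-level model and tokenizer are available.

from threading import Thread
from transformers import TextIteratorStreamer

def generate_streaming(generate_kwargs: dict):
    """Yield progressively longer partial replies, mirroring the removed code."""
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # Run generation on a background thread so the streamer can be consumed here.
    Thread(target=model.generate, kwargs={**generate_kwargs, "streamer": streamer}).start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

The blocking version keeps the handler simpler; the streaming version improves perceived latency on longer replies.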