Norod78 commited on
Commit
d149148
·
1 Parent(s): cca4dd4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ !pip install git+https://github.com/huggingface/transformers
2
+
3
+ import gradio as gr
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
5
+ from threading import Thread
6
+
7
+ tok = AutoTokenizer.from_pretrained("distilgpt2")
8
+ model = AutoModelForCausalLM.from_pretrained("distilgpt2")
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ n_gpu = 0 if torch.cuda.is_available()==False else torch.cuda.device_count()
12
+ model.to(device)
13
+
14
+ early_stop_pattern = tok.eos_token
15
+ print(f'Early stop pattern = \"{early_stop_pattern}\"')
16
+
17
+ def generate(text = ""):
18
+ streamer = TextIteratorStreamer(tok)
19
+ if len(text) == 0:
20
+ text = " "
21
+ inputs = tok([text], return_tensors="pt")
22
+ generation_kwargs = dict(inputs, streamer=streamer, repetition_penalty=2.0, do_sample=True, top_k=40, top_p=0.97, max_new_tokens=128)
23
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
24
+ thread.start()
25
+ generated_text = ""
26
+ for new_text in streamer:
27
+ yield generated_text + new_text
28
+ #print(new_text, end ="")
29
+ generated_text += new_text
30
+ if early_stop_pattern in generated_text:
31
+ generated_text = generated_text[: generated_text.find(early_stop_pattern) if early_stop_pattern else None]
32
+ streamer.end()
33
+ #print("\n--\n")
34
+ yield generated_text
35
+ return
36
+
37
+ demo = gr.Interface(
38
+ title="TextIteratorStreamer + Gradio demo",
39
+ fn=generate,
40
+ inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
41
+ outputs=gr.outputs.Textbox(label="Generated Text"),
42
+ )
43
+
44
+ demo.queue()
45
+ demo.launch()