hiert commited on
Commit
e69a170
·
verified ·
1 Parent(s): e856333

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py CHANGED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import json
4
+ import os
5
+ from screenshot import (
6
+ before_prompt,
7
+ prompt_to_generation,
8
+ after_generation,
9
+ js_save,
10
+ js_load_script,
11
+ )
12
+ from spaces_info import description, examples, initial_prompt_value
13
+
14
+ API_URL = os.getenv("API_URL")
15
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
16
+
17
+
18
+ def query(payload):
19
+ print(payload)
20
+ response = requests.request("POST", API_URL, json=payload, headers={"Authorization": f"Bearer {HF_API_TOKEN}"})
21
+ print(response)
22
+ return json.loads(response.content.decode("utf-8"))
23
+
24
+
25
+ def inference(input_sentence, max_length, sample_or_greedy, seed=42):
26
+ if sample_or_greedy == "Sample":
27
+ parameters = {
28
+ "max_new_tokens": max_length,
29
+ "top_p": 0.9,
30
+ "do_sample": True,
31
+ "seed": seed,
32
+ "early_stopping": False,
33
+ "length_penalty": 0.0,
34
+ "eos_token_id": None,
35
+ }
36
+ else:
37
+ parameters = {
38
+ "max_new_tokens": max_length,
39
+ "do_sample": False,
40
+ "seed": seed,
41
+ "early_stopping": False,
42
+ "length_penalty": 0.0,
43
+ "eos_token_id": None,
44
+ }
45
+
46
+ payload = {"inputs": input_sentence, "parameters": parameters,"options" : {"use_cache": False} }
47
+
48
+ data = query(payload)
49
+
50
+ if "error" in data:
51
+ return (None, None, f"<span style='color:red'>ERROR: {data['error']} </span>")
52
+
53
+ generation = data[0]["generated_text"].split(input_sentence, 1)[1]
54
+ return (
55
+ before_prompt
56
+ + input_sentence
57
+ + prompt_to_generation
58
+ + generation
59
+ + after_generation,
60
+ data[0]["generated_text"],
61
+ "",
62
+ )
63
+
64
+
65
+ if __name__ == "__main__":
66
+ demo = gr.Blocks()
67
+ with demo:
68
+ with gr.Row():
69
+ gr.Markdown(value=description)
70
+ with gr.Row():
71
+ with gr.Column():
72
+ text = gr.Textbox(
73
+ label="Input",
74
+ value=" ", # should be set to " " when plugged into a real API
75
+ )
76
+ tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")
77
+ sampling = gr.Radio(
78
+ ["Sample", "Greedy"], label="Sample or greedy", value="Sample"
79
+ )
80
+ sampling2 = gr.Radio(
81
+ ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"],
82
+ value="Sample 1",
83
+ label="Sample other generations (only work in 'Sample' mode)",
84
+ type="index",
85
+ )
86
+
87
+ with gr.Row():
88
+ submit = gr.Button("Submit")
89
+ load_image = gr.Button("Generate Image")
90
+ with gr.Column():
91
+ text_error = gr.Markdown(label="Log information")
92
+ text_out = gr.Textbox(label="Output")
93
+ display_out = gr.HTML(label="Image")
94
+ display_out.set_event_trigger(
95
+ "load",
96
+ fn=None,
97
+ inputs=None,
98
+ outputs=None,
99
+ no_target=True,
100
+ js=js_load_script,
101
+ )
102
+ with gr.Row():
103
+ gr.Examples(examples=examples, inputs=[text, tokens, sampling, sampling2])
104
+
105
+ submit.click(
106
+ inference,
107
+ inputs=[text, tokens, sampling, sampling2],
108
+ outputs=[display_out, text_out, text_error],
109
+ )
110
+
111
+ load_image.click(fn=None, inputs=None, outputs=None, _js=js_save)
112
+
113
+ demo.launch()