SarowarSaurav committed
Commit e93de88 · verified · 1 Parent(s): f1e4bf8

Create app.py

Files changed (1)
  1. app.py +255 -0
app.py ADDED
@@ -0,0 +1,255 @@
+ import os
+ from threading import Thread
+ from typing import Iterator
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
+
+ MAX_MAX_NEW_TOKENS = 1024
+ DEFAULT_MAX_NEW_TOKENS = 512
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ DESCRIPTION = """\
+ # Try Patched Chat
+ """
+
+ LICENSE = """\
+ ---
+ This space was created by [patched](https://patched.codes).
+ """
+
+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+
+ if torch.cuda.is_available():
+     model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     tokenizer.padding_side = 'right'
+     # pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+     # tokenizer.use_default_system_prompt = False  # Qwen/CodeQwen1.5-7B-Chat
+
+ @spaces.GPU(duration=60)
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     system_prompt: str,
+     max_new_tokens: int = 1024,
+     temperature: float = 0.2,
+     top_p: float = 0.95,
+     # top_k: int = 50,
+     # repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = []
+     if system_prompt:
+         conversation.append({"role": "system", "content": system_prompt})
+     for user, assistant in chat_history:
+         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+     conversation.append({"role": "user", "content": message})
+
+     # prompt = pipe.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+     # outputs = pipe(prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p,
+     #                eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)
+
+     # return outputs[0]['generated_text'][len(prompt):].strip()
+
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     terminators = [
+         tokenizer.eos_token_id,
+         tokenizer.convert_tokens_to_ids("<|eot_id|>")
+     ]
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         # top_k=top_k,
+         temperature=temperature,
+         eos_token_id=terminators,
+         # eos_token_id=tokenizer.eos_token_id,
+         # pad_token_id=tokenizer.pad_token_id,
+         # num_beams=1,
+         # repetition_penalty=1.2,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+ example1 = '''Fix vulnerability CWE-327: Use of a Broken or Risky Cryptographic Algorithm in the following code snippet.
+
+ def md5_hash(path):
+     with open(path, "rb") as f:
+         content = f.read()
+     return hashlib.md5(content).hexdigest()
+ '''
+
+ example2 = '''Carefully analyze the given old code and new code and generate a summary of the changes.
+
+ Old Code:
+
+ #include <stdio.h>
+ #include <stdlib.h>
+
+ typedef struct Node {
+     int data;
+     struct Node *next;
+ } Node;
+
+ void processList() {
+     Node *head = (Node*)malloc(sizeof(Node));
+     head->data = 1;
+     head->next = (Node*)malloc(sizeof(Node));
+     head->next->data = 2;
+
+     printf("First element: %d\n", head->data);
+
+     free(head->next);
+     free(head);
+
+     printf("Accessing freed list: %d\n", head->next->data);
+ }
+
+ New Code:
+
+ #include <stdio.h>
+ #include <stdlib.h>
+
+ typedef struct Node {
+     int data;
+     struct Node *next;
+ } Node;
+
+ void processList() {
+     Node *head = (Node*)malloc(sizeof(Node));
+     if (head == NULL) {
+         perror("Failed to allocate memory for head");
+         return;
+     }
+
+     head->data = 1;
+     head->next = (Node*)malloc(sizeof(Node));
+     if (head->next == NULL) {
+         free(head);
+         perror("Failed to allocate memory for next node");
+         return;
+     }
+     head->next->data = 2;
+
+     printf("First element: %d\n", head->data);
+
+     free(head->next);
+     head->next = NULL;
+     free(head);
+     head = NULL;
+
+     if (head != NULL && head->next != NULL) {
+         printf("Accessing freed list: %d\n", head->next->data);
+     }
+ }
+ '''
+
+ example3 = '''Is the following code prone to CWE-117: Improper Output Neutralization for Logs. Respond only with YES or NO.
+
+ from flask import Flask, request, jsonify
+ import logging
+
+ app = Flask(__name__)
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ @app.route('/api/data', methods=['GET'])
+ def get_data():
+     api_key = request.args.get('api_key')
+     logger.info("Received request with API Key: %s", api_key)
+     data = {"message": "Data processed"}
+     return jsonify(data)
+ '''
+
+ example4 = '''Fix vulnerability CWE-78: Improper Neutralization of Special Elements used in an OS Command ('OS Command Injection') in the following code snippet.
+
+ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
+     if desc is not None:
+         print(desc)
+     run_kwargs = {{
+         "args": command,
+         "shell": True,
+         "env": os.environ if custom_env is None else custom_env,
+         "encoding": 'utf8',
+         "errors": 'ignore',
+     }}
+     if not live:
+         run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
+     result = subprocess.run(**run_kwargs) ##here
+     if result.returncode != 0:
+         error_bits = [
+             f"{{errdesc or 'Error running command'}}.",
+             f"Command: {{command}}",
+             f"Error code: {{result.returncode}}",
+         ]
+         if result.stdout:
+             error_bits.append(f"stdout: {{result.stdout}}")
+         if result.stderr:
+             error_bits.append(f"stderr: {{result.stderr}}")
+         raise RuntimeError("\n".join(error_bits))
+     return (result.stdout or "")
+ '''
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
+     chatbot=gr.Chatbot(height="480px"),
+     additional_inputs=[
+         gr.Textbox(label="System prompt", lines=4),
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.2,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.95,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["You are a helpful coding assistant. Create a snake game in Python."],
+         [example1],
+         [example2],
+         [example3],
+         [example4],
+     ],
+ )
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(DESCRIPTION)
+     chat_interface.render()
+     gr.Markdown(LICENSE, elem_classes="contain")
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
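
Note: passing `load_in_4bit=True` directly to `AutoModelForCausalLM.from_pretrained` (line 29 of the added file) requires the bitsandbytes package and is deprecated in recent transformers releases in favor of an explicit quantization config. A minimal sketch of the equivalent model loading, assuming a recent transformers version with bitsandbytes installed (not part of this commit):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    # Quantize weights to 4-bit at load time and compute in bfloat16.
    bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
    tokenizer = AutoTokenizer.from_pretrained(model_id)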