Alfasign commited on
Commit
707e859
·
1 Parent(s): a499ce9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -149
app.py CHANGED
@@ -1,151 +1,120 @@
1
- import streamlit as st
2
- import openai
3
- import re
4
- import csv
5
- import base64
6
- from io import StringIO
7
- import threading
8
- from queue import Queue
9
-
10
- st.title("EinfachChatProjekt")
11
-
12
- api_key = st.sidebar.text_input("API Key:", value="sk-")
13
- openai.api_key = api_key
14
-
15
- show_notes = st.sidebar.checkbox("Show Notes", value="TRUE")
16
- data_section = st.sidebar.text_area("CSV or Text Data:")
17
- paste_data = st.sidebar.button("Paste Data")
18
- num_concurrent_calls = st.sidebar.number_input("Concurrent Calls:", min_value=1, max_value=2000, value=50, step=1)
19
- generate_all = st.sidebar.button("Generate All")
20
- reset = st.sidebar.button("Reset")
21
- add_row = st.sidebar.button("Add row")
22
- model = st.sidebar.selectbox("Model:", ["gpt-4", "gpt-3.5-turbo"])
23
- temperature = st.sidebar.slider("Temperature:", 0.0, 1.0, 0.6, step=0.01)
24
- max_tokens = st.sidebar.number_input("Max Tokens:", min_value=1, max_value=8192, value=2000, step=1)
25
- top_p = st.sidebar.slider("Top P:", 0.0, 1.0, 1.0, step=0.01)
26
- system_message = st.sidebar.text_area("System Message:")
27
- row_count = st.session_state.get("row_count", 1)
28
-
29
- if add_row:
30
- row_count += 1
31
- st.session_state.row_count = row_count
32
-
33
- if paste_data:
34
- data = StringIO(data_section.strip())
35
- reader = csv.reader(data, delimiter='\n', quotechar='"')
36
- messages = [row[0] for row in reader]
37
- if show_notes:
38
- row_count = len(messages) // 2
39
- for i in range(row_count):
40
- st.session_state[f"note{i}"] = messages[i * 2]
41
- st.session_state[f"message{i}"] = messages[i * 2 + 1]
42
- else:
43
- row_count = len(messages)
44
- for i, message in enumerate(messages):
45
- st.session_state[f"message{i}"] = message
46
- st.session_state.row_count = row_count
47
-
48
- if reset:
49
- row_count = 1
50
- st.session_state.row_count = row_count
51
- for i in range(100): # Assuming a maximum of 100 rows
52
- st.session_state[f"note{i}"] = ""
53
- st.session_state[f"message{i}"] = ""
54
- st.session_state[f"response{i}"] = ""
55
- st.session_state[f"prompt_tokens{i}"] = 0
56
- st.session_state[f"response_tokens{i}"] = 0
57
- st.session_state[f"word_count{i}"] = 0
58
-
59
- def generate_response(i, message):
60
- try:
61
- completion = openai.ChatCompletion.create(
62
- model=model,
63
- messages=[
64
- {"role": "system", "content": system_message},
65
- {"role": "user", "content": message}
66
- ],
67
- temperature=temperature,
68
- max_tokens=max_tokens,
69
- top_p=top_p
70
  )
71
 
72
- response = completion.choices[0].message.content
73
- prompt_tokens = completion.usage['prompt_tokens']
74
- response_tokens = completion.usage['total_tokens'] - prompt_tokens
75
- word_count = len(re.findall(r'\w+', response))
76
-
77
- return (i, response, prompt_tokens, response_tokens, word_count)
78
-
79
- except Exception as e:
80
- return (i, str(e), 0, 0, 0)
81
-
82
- def worker(q, results):
83
- for item in iter(q.get, None):
84
- results.put(generate_response(*item))
85
-
86
- class WorkerThread(threading.Thread):
87
- def __init__(self, input_queue, output_queue):
88
- threading.Thread.__init__(self)
89
- self.input_queue = input_queue
90
- self.output_queue = output_queue
91
- self.daemon = True
92
-
93
- def run(self):
94
- while True:
95
- i, message = self.input_queue.get()
96
- try:
97
- result = generate_response(i, message)
98
- self.output_queue.put(result)
99
- finally:
100
- self.input_queue.task_done()
101
-
102
- if generate_all:
103
- jobs = Queue()
104
- results = Queue()
105
-
106
- workers = [WorkerThread(jobs, results) for _ in range(num_concurrent_calls)]
107
-
108
- for worker in workers:
109
- worker.start()
110
-
111
- for i in range(row_count):
112
- message = st.session_state.get(f"message{i}", "")
113
- jobs.put((i, message))
114
-
115
- jobs.join()
116
-
117
- while not results.empty():
118
- i, response, prompt_tokens, response_tokens, word_count = results.get()
119
- st.session_state[f"response{i}"] = response
120
- st.session_state[f"prompt_tokens{i}"] = prompt_tokens
121
- st.session_state[f"response_tokens{i}"] = response_tokens
122
- st.session_state[f"word_count{i}"] = word_count
123
-
124
- def create_download_link(text, filename):
125
- b64 = base64.b64encode(text.encode()).decode()
126
- href = f'<a href="data:file/txt;base64,{b64}" download="{filename}">Download {filename}</a>'
127
- return href
128
-
129
- for i in range(row_count):
130
- if show_notes:
131
- st.text_input(f"Note {i + 1}:", key=f"note{i}", value=st.session_state.get(f"note{i}", ""))
132
- col1, col2 = st.columns(2)
133
-
134
- with col1:
135
- message = st.text_area(f"Message {i + 1}:", key=f"message{i}", value=st.session_state.get(f"message{i}", ""))
136
-
137
- if st.button(f"Generate Response {i + 1}") and not st.session_state.get(f"response{i}", ""):
138
- response, prompt_tokens, response_tokens, word_count = generate_response(i, message)
139
- st.session_state[f"response{i}"] = response
140
- st.session_state[f"prompt_tokens{i}"] = prompt_tokens
141
- st.session_state[f"response_tokens{i}"] = response_tokens
142
- st.session_state[f"word_count{i}"] = word_count
143
-
144
- with col2:
145
- st.text_area(f"Response {i + 1}:", value=st.session_state.get(f"response{i}", ""))
146
- st.write(f"Tokens: {st.session_state.get(f'prompt_tokens{i}', 0)} / {st.session_state.get(f'response_tokens{i}', 0)} + Words: {st.session_state.get(f'word_count{i}', 0)}")
147
-
148
- responses_text = "\n\n".join([f"{st.session_state.get(f'note{i}', '')}\n{st.session_state.get(f'response{i}', '')}" for i in range(row_count) if show_notes] + [st.session_state.get(f"response{i}", "") for i in range(row_count) if not show_notes])
149
- download_filename = "GPT-4 Responses.txt"
150
- download_link = create_download_link(responses_text, download_filename)
151
- st.markdown(download_link, unsafe_allow_html=True)
 
1
+ from typing import Any, Dict, Tuple
2
+ import warnings
3
+
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ from transformers import (
7
+ StoppingCriteria,
8
+ StoppingCriteriaList,
9
+ TextIteratorStreamer,
10
+ )
11
+
12
+
13
+ INSTRUCTION_KEY = "### Instruction:"
14
+ RESPONSE_KEY = "### Response:"
15
+ END_KEY = "### End"
16
+ INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
17
+ PROMPT_FOR_GENERATION_FORMAT = """{intro}
18
+ {instruction_key}
19
+ {instruction}
20
+ {response_key}
21
+ """.format(
22
+ intro=INTRO_BLURB,
23
+ instruction_key=INSTRUCTION_KEY,
24
+ instruction="{instruction}",
25
+ response_key=RESPONSE_KEY,
26
+ )
27
+
28
+
29
+ class InstructionTextGenerationPipeline:
30
+ def __init__(
31
+ self,
32
+ model_name,
33
+ torch_dtype=torch.bfloat16,
34
+ trust_remote_code=True,
35
+ use_auth_token=None,
36
+ ) -> None:
37
+ self.model = AutoModelForCausalLM.from_pretrained(
38
+ model_name,
39
+ torch_dtype=torch_dtype,
40
+ trust_remote_code=trust_remote_code,
41
+ use_auth_token=use_auth_token,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  )
43
 
44
+ tokenizer = AutoTokenizer.from_pretrained(
45
+ model_name,
46
+ trust_remote_code=trust_remote_code,
47
+ use_auth_token=use_auth_token,
48
+ )
49
+ if tokenizer.pad_token_id is None:
50
+ warnings.warn(
51
+ "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
52
+ )
53
+ tokenizer.pad_token = tokenizer.eos_token
54
+ tokenizer.padding_side = "left"
55
+ self.tokenizer = tokenizer
56
+
57
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+ self.model.eval()
59
+ self.model.to(device=device, dtype=torch_dtype)
60
+
61
+ self.generate_kwargs = {
62
+ "temperature": 0.1,
63
+ "top_p": 0.92,
64
+ "top_k": 0,
65
+ "max_new_tokens": 1024,
66
+ "use_cache": True,
67
+ "do_sample": True,
68
+ "eos_token_id": self.tokenizer.eos_token_id,
69
+ "pad_token_id": self.tokenizer.pad_token_id,
70
+ "repetition_penalty": 1.1, # 1.0 means no penalty, > 1.0 means penalty, 1.2 from CTRL paper
71
+ }
72
+
73
+ def format_instruction(self, instruction):
74
+ return PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
75
+
76
+ def __call__(
77
+ self, instruction: str, **generate_kwargs: Dict[str, Any]
78
+ ) -> Tuple[str, str, float]:
79
+ s = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
80
+ input_ids = self.tokenizer(s, return_tensors="pt").input_ids
81
+ input_ids = input_ids.to(self.model.device)
82
+ gkw = {**self.generate_kwargs, **generate_kwargs}
83
+ with torch.no_grad():
84
+ output_ids = self.model.generate(input_ids, **gkw)
85
+ # Slice the output_ids tensor to get only new tokens
86
+ new_tokens = output_ids[0, len(input_ids[0]) :]
87
+ output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
88
+ return output_text
89
+
90
+ # Initialize the model and tokenizer
91
+ generate = InstructionTextGenerationPipeline(
92
+ "mosaicml/mpt-7b-instruct",
93
+ torch_dtype=torch.bfloat16,
94
+ trust_remote_code=True,
95
+ )
96
+ stop_token_ids = generate.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
97
+
98
+
99
+ # Define a custom stopping criteria
100
+ class StopOnTokens(StoppingCriteria):
101
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
102
+ for stop_id in stop_token_ids:
103
+ if input_ids[0][-1] == stop_id:
104
+ return True
105
+ return False
106
+
107
+ """### The prompt & response"""
108
+
109
+ import json
110
+ import textwrap
111
+
112
+ def get_prompt(instruction):
113
+ prompt_template = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:"
114
+ return prompt_template
115
+
116
+ # print(get_prompt('What is the meaning of life?'))
117
+
118
+ def parse_text(text):
119
+ wrapped_text = textwrap.fill(text, width=100)
120
+ print(wrapped_text +'\n\n')