zjXu11 commited on
Commit
b1ea0d1
·
verified ·
1 Parent(s): 9d9cd7e

Upload openai_utils.py

Browse files
Files changed (1) hide show
  1. utils/openai_utils.py +161 -0
utils/openai_utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import os
4
+ from typing import Any
5
+ from aiohttp import ClientSession
6
+ from tqdm.asyncio import tqdm_asyncio
7
+ import random
8
+ from time import sleep
9
+ import sys
10
+ import aiolimiter
11
+
12
+ import openai
13
+ from openai import AsyncOpenAI, OpenAIError
14
+
15
+
16
def prepare_message(SYSTEM_INPUT, USER_INPUT):
    """Build a two-message chat payload in the OpenAI chat-completions format.

    Args:
        SYSTEM_INPUT: Text placed in the system-role message.
        USER_INPUT: Text placed in the user-role message.

    Returns:
        A list of two role/content dicts: system first, then user.
    """
    return [
        {"role": "system", "content": SYSTEM_INPUT},
        {"role": "user", "content": USER_INPUT},
    ]
28
+
29
def prepare_remove_message(USER_INPUT):
    """Build a chat payload with a fixed system prompt asking the model to strip
    experimental-design/result sentences from the user text.

    Args:
        USER_INPUT: Text placed in the user-role message.

    Returns:
        A list of two role/content dicts: the fixed system prompt, then the user text.
    """
    system_message = {
        "role": "system",
        "content": "Remove sentences about experimental design and results: ",
    }
    user_message = {"role": "user", "content": USER_INPUT}
    return [system_message, user_message]
41
+
42
def prepare_generation_input(title, abstract, sections, filepath):
    """Load a system prompt from *filepath* and format paper fields as the user prompt.

    Args:
        title: Paper title.
        abstract: Paper abstract.
        sections: Paper section text.
        filepath: Path to a UTF-8 text file containing the system prompt.

    Returns:
        A (system_prompt, user_prompt) tuple ready for the chat API.
    """
    with open(filepath, 'r', encoding='utf-8') as prompt_file:
        system_prompt = prompt_file.read()
    user_prompt = (
        f"Paper title: {title}\n\n"
        f"Paper abstract: {abstract}\n\n"
        f"Paper Sections: {sections}"
    )
    return system_prompt, user_prompt
46
+
47
def prepare_remove_input(title, abstract, introduction, filepath):
    """Load a system prompt from *filepath* and format paper fields as the user prompt
    for the "remove experimental sentences" task.

    Args:
        title: Paper title.
        abstract: Paper abstract.
        introduction: Paper introduction text.
        filepath: Path to a UTF-8 text file containing the system prompt.

    Returns:
        A (system_prompt, user_prompt) tuple ready for the chat API.
    """
    with open(filepath, 'r', encoding='utf-8') as prompt_file:
        system_prompt = prompt_file.read()
    # Fix: removed the stray debug print(SYSTEM_INPUT) that dumped the whole
    # prompt file to stdout on every call.
    return system_prompt, f"Paper title: {title}\n\nPaper abstract: {abstract}\n\nIntroduction: {introduction}\n\n"
52
+
53
+
54
async def _throttled_openai_chat_completion_acreate(
    client: AsyncOpenAI,
    model: str,
    messages,
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
    response_format=None,
):
    """Issue one chat-completion request under a shared rate limiter, with retries.

    Args:
        client: Configured AsyncOpenAI client.
        model: Model/engine name for the request.
        messages: Chat messages (list of role/content dicts) for one request.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens to generate.
        top_p: Nucleus-sampling probability mass.
        limiter: Shared aiolimiter.AsyncLimiter throttling request rate.
        response_format: Optional {"type": ...} dict. None or {"type": "text"}
            means plain text and the parameter is omitted from the API call.
            (Fix: the old default of a mutable {} raised KeyError on
            response_format["type"].)

    Returns:
        The API response object, or None on a bad request or after all
        retry attempts are exhausted.
    """
    async with limiter:
        for _ in range(10):
            kwargs = dict(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=top_p,
            )
            # Only forward response_format when something other than plain
            # text is requested, matching the original text/non-text split.
            if response_format and response_format.get("type", "text") != "text":
                kwargs["response_format"] = response_format
            try:
                return await client.chat.completions.create(**kwargs)
            except openai.BadRequestError as e:
                # A malformed request will not succeed on retry; give up now.
                print(e)
                return None
            except OpenAIError as e:
                # Transient error: back off and retry. Fixes two bugs in the
                # original handler: it called the blocking time.sleep (stalling
                # the whole event loop) and then returned None immediately,
                # which made the 10-attempt retry loop dead code.
                print(e)
                await asyncio.sleep(random.randint(5, 10))
    # All retries exhausted.
    return None
91
+
92
+
93
async def generate_from_openai_chat_completion(
    client,
    messages,
    engine_name: str,
    temperature: float = 1.0,
    max_tokens: int = 512,
    top_p: float = 1.0,
    requests_per_minute: int = 100,
    response_format: dict = None,
):
    """Generate from the OpenAI Chat Completion API for a batch of message lists.

    Args:
        client: AsyncOpenAI client used for every request.
        messages: List of message lists to proceed (one chat per element).
        engine_name: Engine name to use, see https://platform.openai.com/docs/models
        temperature: Temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top p to use.
        requests_per_minute: Number of requests per minute to allow.
        response_format: Optional {"type": ...} dict; defaults to plain text.

    Returns:
        List of generated responses; requests that failed yield "Invalid Message".
    """
    # Fix: the original used a mutable dict literal as the default argument;
    # use the None sentinel and materialize the same default here.
    if response_format is None:
        response_format = {"type": "text"}

    limiter = aiolimiter.AsyncLimiter(requests_per_minute)

    # Fan out one throttled coroutine per chat; the shared limiter enforces
    # the requests-per-minute budget across all of them.
    async_responses = [
        _throttled_openai_chat_completion_acreate(
            client,
            model=engine_name,
            messages=message,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
            response_format=response_format,
        )
        for message in messages
    ]

    responses = await tqdm_asyncio.gather(*async_responses, file=sys.stdout)

    # A None response marks a request that failed (bad request or retries
    # exhausted); keep the output list aligned with the input order.
    outputs = []
    for response in responses:
        if response:
            outputs.append(response.choices[0].message.content)
        else:
            outputs.append("Invalid Message")
    return outputs
139
+
140
+
141
# Example usage
if __name__ == "__main__":
    # Fix: the original unconditionally overwrote OPENAI_API_KEY with the
    # placeholder "xxx" (clobbering any real key in the environment) and then
    # assigned AsyncOpenAI.api_key as a class attribute, which the already
    # constructed client never reads. Use setdefault so a real key wins, and
    # let the client pick the key up from the environment at construction.
    os.environ.setdefault("OPENAI_API_KEY", "xxx")  # replace "xxx" with your key

    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the purpose of life? Output result in json format."},
    ]
    responses = asyncio.run(
        generate_from_openai_chat_completion(
            client,
            messages=[messages] * 50,
            engine_name="gpt-3.5-turbo-0125",
            max_tokens=256,
            response_format={"type": "json_object"},
        )
    )
    print(responses)