import asyncio
import os
import random
import sys

import aiolimiter
import openai
from openai import AsyncOpenAI, OpenAIError
from tqdm.asyncio import tqdm_asyncio


def prepare_message(SYSTEM_INPUT, USER_INPUT):
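    """Build a system + user message list for a chat completion request."""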
    cur_message = [
        {"role": "system", "content": SYSTEM_INPUT},
        {"role": "user", "content": USER_INPUT},
    ]
    return cur_message


def prepare_remove_message(USER_INPUT):
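    """Build a message list asking the model to strip experimental-design and result sentences."""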
    cur_message = [
        {"role": "system", "content": "Remove sentences about experimental design and results: "},
        {"role": "user", "content": USER_INPUT},
    ]
    return cur_message


def prepare_generation_input(title, abstract, sections, filepath):
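    """Load the system prompt from filepath and build the user prompt from the paper fields."""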
    with open(filepath, 'r', encoding='utf-8') as file:
        SYSTEM_INPUT = file.read()
    return SYSTEM_INPUT, f"Paper title: {title}\n\nPaper abstract: {abstract}\n\nPaper Sections: {sections}"


def prepare_remove_input(title, abstract, introduction, filepath):
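    """Load the system prompt from filepath and build the user prompt from title, abstract, and introduction."""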
    with open(filepath, 'r', encoding='utf-8') as file:
        SYSTEM_INPUT = file.read()
    return SYSTEM_INPUT, f"Paper title: {title}\n\nPaper abstract: {abstract}\n\nIntroduction: {introduction}\n\n"


async def _throttled_openai_chat_completion_acreate(
    client: AsyncOpenAI,
    model: str,
    messages,
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
    response_format: dict = {"type": "text"},
):
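    """Issue a single rate-limited chat completion request, retrying transient API errors."""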
    async with limiter:
        for _ in range(10):
            try:
                if response_format.get("type", "text") == "text":
                    return await client.chat.completions.create(
                        model=model,
                        messages=messages,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        top_p=top_p,
                    )
                else:
                    return await client.chat.completions.create(
                        model=model,
                        messages=messages,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        top_p=top_p,
                        response_format=response_format,
                    )
            except openai.BadRequestError as e:
                # A malformed request will not succeed on retry; give up on it.
                print(e)
                return None
            except OpenAIError as e:
                # Transient failure (rate limit, timeout, ...): back off, then retry.
                # time.sleep would block the event loop, so use asyncio.sleep instead.
                print(e)
                await asyncio.sleep(random.randint(5, 10))
    # All retries exhausted.
    return None


async def generate_from_openai_chat_completion(
    client,
    messages,
    engine_name: str,
    temperature: float = 1.0,
    max_tokens: int = 512,
    top_p: float = 1.0,
    requests_per_minute: int = 100,
    response_format: dict = {"type": "text"},
):
    """Generate from the OpenAI Chat Completions API.

    Args:
        client: AsyncOpenAI client used to issue the requests.
        messages: List of chat message lists to process.
        engine_name: Model name to use, see https://platform.openai.com/docs/models
        temperature: Sampling temperature to use.
        max_tokens: Maximum number of tokens to generate.
        top_p: Top-p value to use.
        requests_per_minute: Number of requests per minute to allow.
        response_format: Response format to request, e.g. {"type": "json_object"}.

    Returns:
        List of generated responses, with "Invalid Message" for failed requests.
    """
    # One shared limiter so all concurrent requests draw from the same rate budget.
    limiter = aiolimiter.AsyncLimiter(requests_per_minute)

    async_responses = [
        _throttled_openai_chat_completion_acreate(
            client,
            model=engine_name,
            messages=message,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            limiter=limiter,
            response_format=response_format,
        )
        for message in messages
    ]

    responses = await tqdm_asyncio.gather(*async_responses, file=sys.stdout)

    # Failed requests come back as None; substitute a sentinel string for them.
    outputs = []
    for response in responses:
        if response:
            outputs.append(response.choices[0].message.content)
        else:
            outputs.append("Invalid Message")
    return outputs


if __name__ == "__main__":
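    # Minimal smoke test: fan the same JSON-mode request out 50 times.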
    os.environ["OPENAI_API_KEY"] = "xxx"  # placeholder; set a real key before running

    # AsyncOpenAI reads OPENAI_API_KEY from the environment at construction time.
    client = AsyncOpenAI()

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the purpose of life? Output result in json format."},
    ]
    responses = asyncio.run(
        generate_from_openai_chat_completion(
            client,
            messages=[messages] * 50,
            engine_name="gpt-3.5-turbo-0125",
            max_tokens=256,
            response_format={"type": "json_object"},
        )
    )
    print(responses)