# ytseg_demo / generate_text_api.py
# Last commit d7545dc by retkowski: "Handle token by token generation"
import json
import logging

import aiohttp
class TextGenerator:
    """HTTP client for a text-generation server with ``/generate`` and
    ``/generate_stream`` endpoints.

    Requests carry ``{"inputs": prompt, "parameters": {...}}``. Non-streaming
    responses are read from ``generated_text``. Streaming responses are read
    line by line from ``data:``-prefixed JSON objects that hold a ``token``
    with ``text`` and ``special`` fields.
    NOTE(review): this matches the Hugging Face text-generation-inference API
    shape; confirm against the deployed server.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, host_url):
        """Derive both endpoint URLs from *host_url*; a trailing "/" is ignored."""
        base = host_url.rstrip("/")
        self.host_url = base + "/generate"
        self.host_url_stream = base + "/generate_stream"

    @staticmethod
    def _build_payload(prompt, max_new_tokens, do_sample, temperature, **extra):
        """Return the JSON request body shared by all generation methods.

        Entries in *extra* are appended to ``parameters``. The streaming
        endpoint uses them for ``stop`` and ``best_of``.
        """
        return {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "do_sample": do_sample,
                "temperature": temperature,
                **extra,
            },
        }

    @staticmethod
    def _json_headers():
        """Return a fresh header dict for the non-streaming requests."""
        return {"Content-Type": "application/json"}

    @staticmethod
    def _stream_headers():
        """Return a fresh header dict for the streaming request."""
        return {
            "Content-Type": "application/json",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }

    @staticmethod
    def _parse_stream_line(line):
        """Extract the token text from one raw (bytes) stream line.

        Returns ``None`` for an empty line, for a line that does not start
        with ``data:``, and for a token flagged ``special``. Otherwise returns
        the token's ``text``, which may be the empty string.

        Raises:
            json.JSONDecodeError: the ``data:`` payload is not valid JSON.
            KeyError: the payload lacks ``token``/``text``/``special``.
        """
        if not line:
            return None
        decoded = line.decode("utf-8")
        if not decoded.startswith("data:"):
            return None
        # json.loads tolerates the optional space after "data:".
        token = json.loads(decoded.removeprefix("data:"))["token"]
        return None if token["special"] else token["text"]

    async def generate_text_async(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Asynchronously request a completion for *prompt* from ``/generate``.

        Returns:
            The ``generated_text`` field of the JSON response, or ``None`` when
            the server answers with a status other than 200. The non-200 case
            is logged but not raised, so callers can keep checking for ``None``.
        """
        payload = self._build_payload(prompt, max_new_tokens, do_sample, temperature)
        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.host_url, data=json.dumps(payload), headers=self._json_headers()
            ) as response:
                if response.status != 200:
                    self._logger.warning(
                        "POST %s failed with HTTP %s", self.host_url, response.status
                    )
                    return None
                data = await response.json()
                return data["generated_text"]

    def generate_text(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Synchronously request a completion for *prompt* from ``/generate``.

        Returns:
            The ``generated_text`` field of the JSON response.

        Raises:
            requests.HTTPError: the server answered with a 4xx/5xx status.
        """
        import requests

        payload = self._build_payload(prompt, max_new_tokens, do_sample, temperature)
        response = requests.post(
            self.host_url, data=json.dumps(payload), headers=self._json_headers()
        )
        # Fail with the HTTP status instead of an opaque KeyError on the error body.
        response.raise_for_status()
        return response.json()["generated_text"]

    def generate_text_stream(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8, stop=None, best_of=1):
        """Stream a completion for *prompt* from ``/generate_stream``.

        Args:
            stop: Stop sequences forwarded as ``parameters.stop``. ``None``
                (the default) sends an empty list.
            best_of: Forwarded as ``parameters.best_of``.

        Yields:
            The text of each non-special token, in arrival order.

        Raises:
            requests.HTTPError: the server answered with a 4xx/5xx status.
        """
        import requests

        if stop is None:
            stop = []  # fresh list per call; avoids a shared mutable default
        payload = self._build_payload(
            prompt, max_new_tokens, do_sample, temperature, stop=stop, best_of=best_of
        )
        # The with-block closes the streamed response (releasing the pooled
        # connection) even if the consumer abandons this generator early.
        with requests.post(
            self.host_url_stream,
            data=json.dumps(payload),
            headers=self._stream_headers(),
            stream=True,
        ) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    self._logger.debug("stream line: %r", line)
                token = self._parse_stream_line(line)
                if token is not None:
                    yield token
class SummarizerGenerator:
    """Streaming client for a summarization HTTP endpoint.

    Sends ``{"text": ...}`` to the configured URL and yields the response body
    one non-empty line at a time, reformatted by ``_format_line``.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, api):
        """Store *api*, the full URL the summary request is POSTed to."""
        self.api = api

    @staticmethod
    def _format_line(line):
        """Convert one raw (bytes) response line into display text.

        - A trailing ``<|eot_id|>`` marker is stripped.
        - If the line starts with ``•``, every ``•`` in it becomes ``-``.
        - Lines starting with ``-`` get a ``"\\n\\n"`` prefix, presumably so
          that bullets render as separate list items when chunks are
          concatenated. NOTE(review): confirm with the UI consumer.
        """
        text = line.decode("utf-8").removesuffix("<|eot_id|>")
        if text.startswith("•"):
            text = text.replace("•", "-")
        if text.startswith("-"):
            text = "\n\n" + text
        return text

    def generate_summary_stream(self, text):
        """Request a summary of *text* and stream it back line by line.

        Yields:
            Each non-empty response line after ``_format_line`` processing.

        Raises:
            requests.HTTPError: the server answered with a 4xx/5xx status.
                Without this check the error body would be yielded as if it
                were summary text.
        """
        import requests

        payload = {"text": text}
        headers = {
            "Content-Type": "application/json",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        }
        # The with-block closes the streamed response even if the consumer
        # stops iterating before the body is exhausted.
        with requests.post(
            self.api, data=json.dumps(payload), headers=headers, stream=True
        ) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if not line:
                    continue
                self._logger.debug("summary line: %r", line)
                yield self._format_line(line)