Spaces:
Running
Running
import json
import logging

import aiohttp
class TextGenerator:
    """HTTP client for a text-generation server.

    Requests go to ``<host_url>/generate`` (whole completion) and
    ``<host_url>/generate_stream`` (streamed ``data:``-prefixed token events).
    The request/response shapes used below look like the Hugging Face
    text-generation-inference API -- assumption, confirm against the server.
    """

    def __init__(self, host_url):
        """Derive both endpoint URLs from *host_url*; trailing slashes are ignored."""
        base = host_url.rstrip("/")
        self.host_url = base + "/generate"
        self.host_url_stream = base + "/generate_stream"

    @staticmethod
    def _build_payload(prompt, max_new_tokens, do_sample, temperature, **extra_parameters):
        """Return the JSON request body shared by every generate method.

        Keywords in *extra_parameters* are appended, in the order given, to the
        ``parameters`` object after the three common sampling settings.
        """
        parameters = {
            'max_new_tokens': max_new_tokens,
            'do_sample': do_sample,
            'temperature': temperature,
        }
        parameters.update(extra_parameters)
        return {'inputs': prompt, 'parameters': parameters}

    @staticmethod
    def _parse_stream_line(line):
        """Return the token text carried by one raw stream line, or None.

        *line* is a bytes line as produced by ``Response.iter_lines()``. None is
        returned for empty lines, lines without a ``data:`` prefix, and tokens
        whose ``special`` flag is true. An empty-string token is returned as
        ``""``, so callers must test ``is None`` rather than truthiness.
        """
        if not line:
            return None
        decoded = line.decode('utf-8')
        if not decoded.startswith('data:'):
            return None
        token = json.loads(decoded[len('data:'):])['token']
        text = token['text']
        return None if token['special'] else text

    async def generate_text_async(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Asynchronously request a full completion for *prompt*.

        Returns:
            The ``generated_text`` field of the JSON response, or None when the
            server answers with a status other than 200.
        """
        payload = self._build_payload(prompt, max_new_tokens, do_sample, temperature)
        headers = {'Content-Type': 'application/json'}
        async with aiohttp.ClientSession() as session:
            async with session.post(self.host_url, data=json.dumps(payload), headers=headers) as response:
                if response.status != 200:
                    # Errors are reported as None (not raised) to keep this
                    # method's existing contract. TODO: surface the error body.
                    return None
                data = await response.json()
                return data["generated_text"]

    def generate_text(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Request a full completion for *prompt* and return its ``generated_text``.

        Raises:
            requests.HTTPError: if the server returns a 4xx/5xx status.
        """
        import requests
        payload = self._build_payload(prompt, max_new_tokens, do_sample, temperature)
        headers = {'Content-Type': 'application/json'}
        response = requests.post(self.host_url, data=json.dumps(payload), headers=headers)
        # Report the HTTP failure itself instead of a KeyError / JSON decode
        # error raised while reading an error body.
        response.raise_for_status()
        return response.json()["generated_text"]

    def generate_text_stream(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8, stop=None, best_of=1):
        """Stream a completion for *prompt*, yielding token texts as they arrive.

        Args:
            stop: stop sequences sent as ``parameters.stop``; None (the default)
                sends an empty list.
            best_of: sent as ``parameters.best_of``.

        Yields:
            The text of each non-special token, in arrival order.

        Raises:
            requests.HTTPError: if the server returns a 4xx/5xx status.
        """
        import requests
        # None instead of a shared mutable [] default.
        if stop is None:
            stop = []
        payload = self._build_payload(
            prompt, max_new_tokens, do_sample, temperature, stop=stop, best_of=best_of
        )
        headers = {
            'Content-Type': 'application/json',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }
        logger = logging.getLogger(__name__)
        # The with-block closes the streamed response even if the caller stops
        # iterating early, so the connection is released back to the pool.
        with requests.post(self.host_url_stream, data=json.dumps(payload), headers=headers, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    logger.debug("stream line: %r", line)
                token = self._parse_stream_line(line)
                if token is not None:
                    yield token
class SummarizerGenerator:
    """Streaming client for a summarization HTTP endpoint.

    POSTs ``{"text": ...}`` to *api* and yields the streamed response line by
    line, lightly reformatted for Markdown display.
    """

    def __init__(self, api):
        """Store *api*, the full URL the summary request is POSTed to."""
        self.api = api

    @staticmethod
    def _format_chunk(line):
        """Return the display text for one non-empty raw (bytes) stream line.

        A trailing ``<|eot_id|>`` marker (presumably the model's end-of-turn
        token -- confirm) is dropped. A line starting with ``•`` has its ``•``
        characters turned into ``-``. Any line that then starts with ``-`` is
        prefixed with ``"\\n\\n"`` so it begins a new Markdown block.
        """
        chunk = line.decode('utf-8').removesuffix('<|eot_id|>')
        if chunk.startswith("•"):
            # Replaces every "•" in the line, not only the leading one.
            chunk = chunk.replace("•", "-")
        if chunk.startswith("-"):
            chunk = "\n\n" + chunk
        return chunk

    def generate_summary_stream(self, text):
        """POST *text* for summarization and yield the formatted response lines.

        Empty lines are skipped; every other line is passed through
        ``_format_chunk``.

        Raises:
            requests.HTTPError: if the server returns a 4xx/5xx status, instead
                of streaming the error body to the caller as summary text.
        """
        import requests
        payload = {"text": text}
        headers = {
            'Content-Type': 'application/json',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }
        logger = logging.getLogger(__name__)
        # The with-block closes the streamed response even if the caller stops
        # iterating early, so the connection is released back to the pool.
        with requests.post(self.api, data=json.dumps(payload), headers=headers, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    logger.debug("summary line: %r", line)
                    yield self._format_chunk(line)