import json
import aiohttp
import requests


class TextGenerator:
    """Client for a text-generation server exposing /generate and /generate_stream endpoints."""

    def __init__(self, host_url):
        self.host_url = host_url.rstrip("/") + "/generate"
        self.host_url_stream = host_url.rstrip("/") + "/generate_stream"

    async def generate_text_async(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Asynchronously request a completion; return the generated text, or None on an HTTP error."""
        payload = {
            'inputs': prompt,
            'parameters': {
                'max_new_tokens': max_new_tokens,
                'do_sample': do_sample,
                'temperature': temperature,
            }
        }
        headers = {'Content-Type': 'application/json'}
        async with aiohttp.ClientSession() as session:
            async with session.post(self.host_url, data=json.dumps(payload), headers=headers) as response:
                if response.status == 200:
                    data = await response.json()
                    return data["generated_text"]
                # Non-200 responses are treated as failures.
                return None

    def generate_text(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Synchronously request a completion and return the generated text."""
        payload = {
            'inputs': prompt,
            'parameters': {
                'max_new_tokens': max_new_tokens,
                'do_sample': do_sample,
                'temperature': temperature,
            }
        }
        headers = {'Content-Type': 'application/json'}
        response = requests.post(self.host_url, data=json.dumps(payload), headers=headers)
        response.raise_for_status()
        return response.json()["generated_text"]

    def generate_text_stream(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8, stop=None, best_of=1):
        """Stream a completion token by token from the /generate_stream endpoint."""
        payload = {
            'inputs': prompt,
            'parameters': {
                'max_new_tokens': max_new_tokens,
                'do_sample': do_sample,
                'temperature': temperature,
                'stop': stop if stop is not None else [],
                'best_of': best_of,
            }
        }
        headers = {
            'Content-Type': 'application/json',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }
        response = requests.post(self.host_url_stream, data=json.dumps(payload), headers=headers, stream=True)
        for line in response.iter_lines():
            if not line:
                continue
            event = line.decode('utf-8')
            # Server-sent events are prefixed with "data:"; strip it before parsing the JSON payload.
            if event.startswith('data:'):
                token_data = json.loads(event[5:])
                # Skip special tokens (e.g. end-of-sequence markers).
                if not token_data['token']['special']:
                    yield token_data['token']['text']
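
# Usage sketch for the asynchronous client above. Assumption for illustration:
# "http://localhost:8080" stands in for a compatible generation server; it is
# not a URL defined anywhere in this file.
#
#     import asyncio
#     tg = TextGenerator("http://localhost:8080")
#     text = asyncio.run(tg.generate_text_async("Hello, my name is", max_new_tokens=32))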


class SummarizerGenerator:
    """Client for a streaming summarization endpoint that returns plain-text chunks."""

    def __init__(self, api):
        self.api = api

    def generate_summary_stream(self, text):
        """Stream summary chunks, normalizing bullet characters and stripping the end-of-text marker."""
        payload = {"text": text}
        headers = {
            'Content-Type': 'application/json',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive'
        }
        response = requests.post(self.api, data=json.dumps(payload), headers=headers, stream=True)
        for line in response.iter_lines():
            if not line:
                continue
            data = line.decode('utf-8').removesuffix('<|eot_id|>')
            # Normalize "•" bullets to "-" and start each bullet on its own paragraph.
            if data.startswith("•"):
                data = data.replace("•", "-")
            if data.startswith("-"):
                data = "\n\n" + data
            yield data
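

# Minimal demo sketch, guarded so importing this module stays side-effect free.
# The URLs below are placeholders assumed for illustration; no server addresses
# are actually configured in this file.
if __name__ == "__main__":
    generator = TextGenerator("http://localhost:8080")  # placeholder generation host
    print(generator.generate_text("Explain why the sky is blue.", max_new_tokens=64))

    summarizer = SummarizerGenerator("http://localhost:8081/summarize")  # placeholder summarizer endpoint
    for chunk in summarizer.generate_summary_stream("Long input text to summarize..."):
        print(chunk, end="", flush=True)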