"""HTTP clients for a text-generation server and a summarization endpoint."""

import json
import logging

import aiohttp

logger = logging.getLogger(__name__)

# Header sets sent with requests; shared so every method sends the same values.
_JSON_HEADERS = {'Content-Type': 'application/json'}
_STREAM_HEADERS = {
    'Content-Type': 'application/json',
    'Cache-Control': 'no-cache',
    'Connection': 'keep-alive',
}


class TextGenerator:
    """Client for a text-generation server.

    Requests go to ``<host_url>/generate`` for a single JSON reply and to
    ``<host_url>/generate_stream`` for a line-oriented stream whose payload
    lines are prefixed with ``data:``.
    """

    def __init__(self, host_url):
        """Derive the endpoint URLs from *host_url*.

        Args:
            host_url: Base URL of the server; any trailing slashes are ignored.
        """
        base_url = host_url.rstrip("/")
        self.host_url = base_url + "/generate"
        self.host_url_stream = base_url + "/generate_stream"

    @staticmethod
    def _build_payload(prompt, max_new_tokens, do_sample, temperature, **extra_parameters):
        """Return the JSON request body shared by all generation methods.

        Extra keyword arguments are appended to the ``parameters`` object after
        the three common sampling settings.
        """
        parameters = {
            'max_new_tokens': max_new_tokens,
            'do_sample': do_sample,
            'temperature': temperature,
        }
        parameters.update(extra_parameters)
        return {'inputs': prompt, 'parameters': parameters}

    async def generate_text_async(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Request a completion for *prompt* without blocking the event loop.

        Args:
            prompt: Input text sent as ``inputs``.
            max_new_tokens: Upper bound on generated tokens.
            do_sample: Whether the server should sample instead of decoding greedily.
            temperature: Sampling temperature.

        Returns:
            The ``generated_text`` field of the server's JSON reply, or ``None``
            when the server answers with any status other than 200.
        """
        payload = self._build_payload(prompt, max_new_tokens, do_sample, temperature)
        async with aiohttp.ClientSession() as session:
            async with session.post(self.host_url, data=json.dumps(payload),
                                    headers=_JSON_HEADERS) as response:
                if response.status != 200:
                    # Callers rely on None for failures; log so it is not invisible.
                    logger.warning("generate request to %s failed with HTTP %s",
                                   self.host_url, response.status)
                    return None
                data = await response.json()
                return data["generated_text"]

    def generate_text(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8):
        """Request a completion for *prompt* and block until it is returned.

        Args:
            prompt: Input text sent as ``inputs``.
            max_new_tokens: Upper bound on generated tokens.
            do_sample: Whether the server should sample instead of decoding greedily.
            temperature: Sampling temperature.

        Returns:
            The ``generated_text`` field of the server's JSON reply.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
        """
        import requests

        payload = self._build_payload(prompt, max_new_tokens, do_sample, temperature)
        response = requests.post(self.host_url, data=json.dumps(payload), headers=_JSON_HEADERS)
        # Fail with the HTTP status instead of a KeyError on the error body.
        response.raise_for_status()
        return response.json()["generated_text"]

    def generate_text_stream(self, prompt, max_new_tokens=100, do_sample=True, temperature=0.8,
                             stop=None, best_of=1):
        """Yield generated tokens for *prompt* as the server streams them.

        Args:
            prompt: Input text sent as ``inputs``.
            max_new_tokens: Upper bound on generated tokens.
            do_sample: Whether the server should sample instead of decoding greedily.
            temperature: Sampling temperature.
            stop: Stop sequences forwarded to the server; ``None`` sends an empty list.
            best_of: Forwarded unchanged as the ``best_of`` parameter.

        Yields:
            The ``text`` of each streamed token whose ``special`` flag is false.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
        """
        import requests

        payload = self._build_payload(
            prompt, max_new_tokens, do_sample, temperature,
            stop=[] if stop is None else stop,
            best_of=best_of,
        )
        # The context manager closes the streamed response even when the
        # consumer stops iterating early, returning the connection to the pool.
        with requests.post(self.host_url_stream, data=json.dumps(payload),
                           headers=_STREAM_HEADERS, stream=True) as response:
            response.raise_for_status()
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # empty lines carry no payload
                logger.debug("stream line: %r", raw_line)
                line = raw_line.decode('utf-8')
                if not line.startswith('data:'):
                    continue
                token_info = json.loads(line[len('data:'):])['token']
                token_text = token_info['text']
                if not token_info['special']:
                    yield token_text


class SummarizerGenerator:
    """Client for a summarization endpoint that streams its answer line by line."""

    def __init__(self, api):
        """Store the endpoint.

        Args:
            api: Full URL the summarization request is POSTed to.
        """
        self.api = api

    @staticmethod
    def _format_line(raw_line):
        """Decode one streamed line and normalise its bullet formatting.

        A trailing ``<|eot_id|>`` marker is removed. If the line starts with
        ``•``, every ``•`` in it becomes ``-``. A line starting with ``-`` is
        prefixed with a blank line, presumably so each bullet starts a new
        paragraph when rendered.
        """
        text = raw_line.decode('utf-8').removesuffix('<|eot_id|>')
        if text.startswith("•"):
            text = text.replace("•", "-")
        if text.startswith("-"):
            text = "\n\n" + text
        return text

    def generate_summary_stream(self, text):
        """Yield the summary of *text* one formatted line at a time.

        Args:
            text: Document to summarise, sent as the ``text`` field.

        Yields:
            Each non-empty streamed line after ``_format_line`` processing.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status,
                rather than yielding the error body as summary text.
        """
        import requests

        payload = {"text": text}
        # Close the streamed response even if the consumer stops early.
        with requests.post(self.api, data=json.dumps(payload),
                           headers=_STREAM_HEADERS, stream=True) as response:
            response.raise_for_status()
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                logger.debug("summary line: %r", raw_line)
                yield self._format_line(raw_line)