|
import os

import requests
from transformers import Tool
from transformers import pipeline
|
|
|
class TextGenerationTool(Tool):
    """Agent tool that forwards a prompt to a hosted Hugging Face Inference API model.

    The API bearer token is read from the ``HF`` environment variable on every call.
    """

    name = "text_generator"

    description = (
        "This is a tool for text generation. It takes a prompt as input and returns the generated text."
    )

    inputs = ["text"]

    outputs = ["text"]

    # Hosted inference endpoint; override on a subclass/instance to use another model.
    # NOTE(review): this checkpoint looks like a masked-LM (fill-mask) model, not a
    # causal text generator — confirm it matches the tool description above.
    api_url = "https://api-inference.huggingface.co/models/lukasdrg/clinical_longformer_same_tokens_220k"

    # Seconds to wait for the inference API before giving up (requests has no default timeout).
    timeout = 60

    def __call__(self, prompt: str) -> str:
        """Send *prompt* to the inference API and return the model's text output.

        Args:
            prompt: Text sent verbatim as the request's ``inputs`` field.

        Returns:
            The text extracted from the JSON response (see ``_extract_text``).

        Raises:
            KeyError: If the ``HF`` environment variable is not set.
            requests.HTTPError: If the API answers with an error status.
            requests.Timeout: If the API does not answer within ``timeout`` seconds.
        """
        headers = {"Authorization": f"Bearer {os.environ['HF']}"}
        response = requests.post(
            self.api_url,
            headers=headers,
            json={"inputs": prompt},
            timeout=self.timeout,
        )
        # Surface API errors (bad token, model loading, etc.) instead of parsing an error body.
        response.raise_for_status()
        return self._extract_text(response.json())

    @staticmethod
    def _extract_text(data) -> str:
        """Return a single text string from an Inference API JSON payload.

        Text-generation models answer with ``[{"generated_text": ...}]`` and
        fill-mask models with ranked candidates ``[{"sequence": ...}, ...]``;
        the first entry is used in either case. Any other shape is returned
        as its string representation so the tool still yields text.
        """
        if isinstance(data, list) and data and isinstance(data[0], dict):
            first = data[0]
            for key in ("generated_text", "sequence"):
                if key in first:
                    return first[key]
        return str(data)
|
|
|
|
|
|