|
import os |
|
import requests |
|
|
|
from transformers import Tool |
|
|
|
|
|
class TextGenerationTool(Tool):
    """Agent tool that generates text by calling the Hugging Face Inference API.

    The prompt is POSTed to the hosted model at ``api_url``, authenticated with
    the token stored in the ``HF`` environment variable, and the
    ``generated_text`` field of the JSON response is returned.
    """

    name = "text_generator"

    description = (
        "This is a tool for text generation. It takes a prompt as input and returns the generated text."
    )

    inputs = ["text"]

    outputs = ["text"]

    # Endpoint of the hosted model. Override in a subclass or on an instance
    # to target a different model.
    api_url = "https://api-inference.huggingface.co/models/lukasdrg/clinical_longformer_same_tokens_220k"

    # Seconds to wait for the API. Without a timeout, ``requests`` may block
    # indefinitely on a stalled connection.
    timeout = 60

    def __call__(self, prompt: str) -> str:
        """Generate text for *prompt* using the remote model.

        Args:
            prompt: Input text, sent as the ``inputs`` field of the JSON payload.

        Returns:
            The ``generated_text`` value from the API response.

        Raises:
            KeyError: If the ``HF`` environment variable is unset or empty, or
                the response has no ``generated_text`` field.
            ValueError: If the API returns an empty list of candidates.
            requests.HTTPError: If the API responds with a 4xx/5xx status.
            requests.Timeout: If the API does not respond within ``timeout``.
        """
        token = os.environ.get("HF")
        if not token:
            raise KeyError(
                "HF environment variable must be set to a Hugging Face API token"
            )
        headers = {"Authorization": "Bearer " + token}
        payload = {"inputs": prompt}

        response = requests.post(
            self.api_url, headers=headers, json=payload, timeout=self.timeout
        )
        # Fail on HTTP errors here instead of later with a confusing KeyError
        # raised while reading the JSON error body.
        response.raise_for_status()
        result = response.json()

        # The text-generation task returns a list of candidates, e.g.
        # [{"generated_text": "..."}]. Take the first one; a bare dict is
        # accepted as well.
        if isinstance(result, list):
            if not result:
                raise ValueError("Inference API returned no generation candidates")
            result = result[0]
        return result["generated_text"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|