Spaces:
Sleeping
Sleeping
File size: 2,810 Bytes
10625fe e69b8d9 5d53afe 10625fe e69b8d9 c45d8c5 e69b8d9 35d86be 19ed2fd 35d86be e69b8d9 35d86be e69b8d9 19ed2fd 35d86be e69b8d9 35d86be e69b8d9 4f95c8e 4d9e68c 35d86be 19ed2fd 35d86be 19ed2fd 35d86be eab490f 4f95c8e 10625fe 35d86be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer
import json
# Inicialize o cliente e o tokenizador
model_name = "rss9051/autotrein-BERT-iiLEX-dgs-0005"
client = InferenceClient(model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Função para dividir texto em chunks com truncamento garantido
def split_text_into_chunks(text, max_tokens=512):
tokens = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"][0]
chunks = []
for i in range(0, len(tokens), max_tokens):
chunk = tokens[i:i + max_tokens]
if len(chunk) > max_tokens:
chunk = chunk[:max_tokens] # Truncar qualquer excesso
chunks.append(chunk)
return [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in chunks]
# Função para classificar texto longo
def classify_text(text):
chunks = split_text_into_chunks(text, max_tokens=512) # Divida o texto em chunks menores
all_responses = [] # Lista para armazenar respostas de cada chunk
for chunk in chunks:
try:
response_bytes = client.post(json={"inputs": chunk}) # Enviar o chunk
response_str = response_bytes.decode('utf-8') # Decodificar de bytes para string
response = json.loads(response_str) # Converter string JSON para objeto Python
if isinstance(response, list) and len(response) > 0:
sorted_response = sorted(response[0], key=lambda x: x['score'], reverse=True)
all_responses.append(sorted_response[0]) # Adicionar a melhor classificação do chunk
except Exception as e:
print(f"Erro ao processar chunk: {e}")
# Combinar resultados de todos os chunks
if all_responses:
# Contar as classes mais frequentes
class_scores = {}
for res in all_responses:
label = res['label']
score = res['score']
if label in class_scores:
class_scores[label] += score
else:
class_scores[label] = score
# Obter a classe com maior score combinado
predicted_class = max(class_scores, key=class_scores.get)
else:
predicted_class = "Classificação não encontrada"
return predicted_class
# Interface Gradio
demo = gr.Interface(
fn=classify_text, # Função a ser chamada para classificar o texto
inputs=gr.Textbox(label="Texto para Classificação"), # Entrada de texto
outputs=gr.Label(label="Classe Predita"), # Saída da classificação
title="Classificador de Texto", # Título da interface
description="Insira um texto para obter a classificação usando o modelo treinado." # Descrição da interface
)
if __name__ == "__main__":
demo.launch() |