fschwartzer committed
Commit b531f3a · verified · 1 Parent(s): 76302e9

Update app.py

Files changed (1):
  1. app.py +9 -8
app.py CHANGED
@@ -1,10 +1,10 @@
 import pandas as pd
 import gradio as gr
-from transformers import AutoTokenizer, AutoModel
+from transformers import BertTokenizer, BertForSequenceClassification
 import torch
 
-tokenizer = AutoTokenizer.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
-model = AutoModel.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
+tokenizer = BertTokenizer.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
+model = BertForSequenceClassification.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
 
 # Initial data
 data = {
@@ -32,11 +32,12 @@ def get_gpt_response(query):
     {csv_data}
 
     """
-    input_ids = tokenizer.encode(query, return_tensors='pt')
-    max_length = input_ids.shape[1] + 100
-    generated_ids = model.generate(input_ids, max_length=max_length)
-    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-    return generated_text
+    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
+    outputs = model(**inputs)
+    prediction = torch.argmax(outputs.logits, dim=1)
+    labels = ["Label1", "Label2"]  # Replace with your actual labels
+    predicted_label = labels[prediction]
+    return predicted_label
 
 
 def ask_question(pergunta):
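
For reference, a minimal standalone sketch of the classification path this commit switches to. It assumes the juridics/bertimbaulaw-base-portuguese-sts-scale checkpoint loads with a sequence-classification head; the classify() helper and the example input are illustrative only, and the label list is a placeholder mirroring the Label1/Label2 stubs in the diff.

# Sketch of the new BertForSequenceClassification path, exercised in isolation.
# Assumption: the checkpoint provides a classification head; labels are placeholders.
import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
model = BertForSequenceClassification.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
model.eval()

def classify(query: str) -> str:
    # Tokenize the query and run a single forward pass without tracking gradients.
    inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Take the highest-scoring class index from the logits.
    prediction = torch.argmax(outputs.logits, dim=1).item()
    labels = ["Label1", "Label2"]  # placeholder labels, as in the diff
    return labels[prediction]

print(classify("example query"))  # placeholder input

Unlike the previous AutoModel + generate() path, BertForSequenceClassification returns logits over a fixed label set, so the response is an argmax over classes rather than generated text.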