# claim-detect / app.py — Gradio Space: claim veracity classification with a QLoRA-finetuned Falcon-7B
import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization with double quantization and bfloat16 compute, so
# Falcon-7B fits in the memory of a single modest GPU. bitsandbytes must be
# installed for 4-bit loading, but needs no explicit import.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
print("claim")
PEFT_MODEL = "dpaul93/falcon-7b-qlora-chat-claim-data" #"/content/trained-model"
config = PeftConfig.from_pretrained(PEFT_MODEL)
config.base_model_name_or_path = "tiiuae/falcon-7b"
model = AutoModelForCausalLM.from_pretrained(
config.base_model_name_or_path,
return_dict=True,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token  # Falcon's tokenizer defines no pad token
model = PeftModel.from_pretrained(model, PEFT_MODEL)  # attach the LoRA adapter
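
# NOTE: the original script passed `generation_config` to model.generate()
# below without ever defining it. This minimal definition is an assumption
# added so the app runs; the decoding settings were not in the source.
generation_config = model.generation_config
generation_config.max_new_tokens = 20  # the answer is a short label
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id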
def generate_test_prompt(text):
    # Build the classification prompt. The original interpolated an undefined
    # `data_point["claim"]`; the function argument `text` is used instead.
    # A trailing "Answer:" is appended so the decoding step below, which
    # splits on "Answer:", always finds its delimiter.
    return f"""Given the following claim:
{text}.
pick one of the following options:
(a) true
(b) false
(c) mixture
(d) unknown
(e) not_applicable
Answer:""".strip()
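
# For illustration, the rendered prompt for a made-up claim
# ("The Eiffel Tower is located in Berlin") looks like:
#
#   Given the following claim:
#   The Eiffel Tower is located in Berlin.
#   pick one of the following options:
#   (a) true
#   (b) false
#   (c) mixture
#   (d) unknown
#   (e) not_applicable
#   Answer: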
def generate_and_tokenize_prompt(text):
    # Tokenize the prompt, generate with the settings above, and pull the
    # predicted label out of the decoded text: everything after "Answer:"
    # up to the first newline or period.
    prompt = generate_test_prompt(text)
    # model.device follows device_map="auto" (the original hard-coded "cuda").
    encoding = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config,
        )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded.split("Answer:")[1].split("\n")[0].split(".")[0].strip()
def classifyUsingLLAMA(text):
    # Despite the name, this wraps the Falcon-7B QLoRA model loaded above.
    return generate_and_tokenize_prompt(text)
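
# Minimal Gradio UI: a text box for the claim in, the predicted label out.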
iface = gr.Interface(fn=classifyUsingLLAMA, inputs="text", outputs="text")
iface.launch()