import os
import re
import time
import urllib.request

import gradio as gr
import huggingface_hub
import requests
import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

import globals
from utility import *
"""set up"""
huggingface_hub.login(token=globals.HF_TOKEN)
gemma_tokenizer = AutoTokenizer.from_pretrained(globals.gemma_2b_URL)
gemma_model = AutoModelForCausalLM.from_pretrained(globals.gemma_2b_URL)
falcon_tokenizer = AutoTokenizer.from_pretrained(globals.falcon_7b_URL, trust_remote_code=True, device_map=globals.device_map, offload_folder="offload")
falcon_model = AutoModelForCausalLM.from_pretrained(globals.falcon_7b_URL, trust_remote_code=True,
torch_dtype=torch.bfloat16, device_map=globals.device_map, offload_folder="offload")
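# Note: based on the references in this file, the `globals` module is assumed to expose
# HF_TOKEN, device_map, TITLE, the model URLs (gemma_2b_URL, falcon_7b_URL,
# simplet5_base_URL, simplet5_large_URL), the prompt prefixes (gemma_PREFIX,
# falcon_PREFIX, simplet5_PREFIX), and T5_FILE_NAME; `utility` is assumed to provide
# fetch_model, post_process, and api_query.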

def get_model(model_typ):
    """Return (model, tokenizer, prompt prefix) for the requested model type."""
    if model_typ not in ["gemma", "falcon", "falcon_api", "simplet5_base", "simplet5_large"]:
        raise ValueError('Invalid model type. Choose "gemma", "falcon", "falcon_api", "simplet5_base", "simplet5_large".')
    if model_typ == "gemma":
        tokenizer = gemma_tokenizer
        model = gemma_model
        prefix = globals.gemma_PREFIX
    elif model_typ == "falcon_api":
        # The hosted Inference API is used for this option, so no local model is loaded.
        prefix = globals.falcon_PREFIX
        model = None
        tokenizer = None
    elif model_typ == "falcon":
        tokenizer = falcon_tokenizer
        model = falcon_model
        prefix = globals.falcon_PREFIX
    elif model_typ in ["simplet5_base", "simplet5_large"]:
        prefix = globals.simplet5_PREFIX
        URL = globals.simplet5_base_URL if model_typ == "simplet5_base" else globals.simplet5_large_URL
        T5_MODEL_PATH = f"https://huggingface.co/{URL}/resolve/main/{globals.T5_FILE_NAME}"
        fetch_model(T5_MODEL_PATH, globals.T5_FILE_NAME)
        tokenizer = T5Tokenizer.from_pretrained(URL)
        model = T5ForConditionalGeneration.from_pretrained(URL)
    return model, tokenizer, prefix
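# Example usage (sketch):
#   model, tokenizer, prefix = get_model("simplet5_base")
#   # `prefix` contains a "{fig}" placeholder that topk_query fills with the figurative sentence.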

def topk_query(model_typ="gemma", prompt="She has a heart of gold", temperature=0.7, max_length=256):
    """Generate five candidate literal interpretations for a figurative sentence."""
    if model_typ not in ["gemma", "simplet5_base", "simplet5_large"]:
        raise ValueError('Invalid model type. Choose "gemma", "simplet5_base", "simplet5_large".')
    model, tokenizer, prefix = get_model(model_typ)
    start_time = time.time()
    input = prefix.replace("{fig}", prompt)
    print(f"Input to model: \n{input}")
    if model_typ in ["simplet5_base", "simplet5_large"]:
        inputs = tokenizer(input, return_tensors="pt")
        outputs = model.generate(
            inputs["input_ids"],
            temperature=temperature,
            max_length=max_length,
            num_beams=5,
            num_return_sequences=5,  # generate 5 candidate responses
            early_stopping=True
        )
        responses = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        answer = [response.replace(input, "").strip() for response in responses]
    elif model_typ == "gemma":
        inputs = tokenizer(input, return_tensors="pt")
        generate_ids = model.generate(
            inputs.input_ids,
            max_length=max_length,
            do_sample=True,
            top_k=50,
            temperature=temperature,
            num_return_sequences=5,
            eos_token_id=tokenizer.eos_token_id
        )
        outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        print(f"Model original output: {outputs}\n")
        answer = [post_process(output, input).replace("\n", "") for output in outputs]
    # TODO: Falcon's candidates show little variation from one another, so it is not used for the top-k response.
    # elif model_typ == "falcon_api":
    #     API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"
    #     headers = {"Authorization": f"Bearer {access_token}"}
    #     response = api_query(API_URL=API_URL, headers=headers, payload={
    #         "inputs": input,
    #         "parameters": {
    #             "temperature": temperature,
    #             "top_k": 50,
    #             "num_return_sequences": 5
    #         }
    #     })
    #     print(response)
    #     answer = [post_process(item["generated_text"], input) for item in response]
    else:
        raise ValueError('Invalid model type. Choose "gemma", "simplet5_base", "simplet5_large".')
    print(f"Time taken: {time.time() - start_time:.2f} seconds")
    print(f"Processed model output: {answer}")
    return answer
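# Example usage (sketch): returns a list of five candidate literal paraphrases, e.g.
#   candidates = topk_query(model_typ="gemma", prompt="She has a heart of gold", temperature=0.7)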

topk_iface = gr.Interface(
    fn=topk_query,
    inputs=[
        gr.Dropdown(
            choices=["gemma", "simplet5_base", "simplet5_large"],
            label="Model Type",
            value="gemma"
        ),
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature"),
        gr.Slider(minimum=50, maximum=512, step=10, value=256, label="Max Length")
    ],
    outputs=[
        gr.Textbox(label="Response 1"),
        gr.Textbox(label="Response 2"),
        gr.Textbox(label="Response 3"),
        gr.Textbox(label="Response 4"),
        gr.Textbox(label="Response 5")
    ],
    theme=gr.themes.Soft(),
    title=globals.TITLE,
    description="Generate multiple responses (top 5) from the input sentence, the model's prompt prefix, and the temperature. Each response is a literal meaning/explanation of the input figurative sentence.",
    # Each example supplies [model type, prompt, temperature, max length].
    examples=[
        ["gemma", "Time flies when you're having fun", 0.7, 256],
        ["simplet5_large", "She has a heart of gold", 0.5, 256],
        ["gemma", "The sky is the limit", 0.6, 256]
    ]
)
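# Note: topk_query returns a list of five strings, which Gradio maps positionally to the
# five "Response" textboxes above.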

if __name__ == '__main__':
    topk_iface.launch()