gemmaw / app.py
Yahir's picture
Update app.py
8e07109 verified
raw
history blame
1.88 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import random
# Registry of selectable chat models, keyed by HF Hub repo id.
models = ["microsoft/DialoGPT-medium", "facebook/opt-125m"]
# One tokenizer per model, keyed by the same repo id as `clients`.
tokenizers = {name: AutoTokenizer.from_pretrained(name) for name in models}
# Model objects kept on CPU, keyed by repo id.
clients = {name: AutoModelForCausalLM.from_pretrained(name).to('cpu') for name in models}
# NOTE(review): "nchlt/omnibus-image-current-tab" does not look like a
# causal-LM repo id, and .to('cuda') will crash on CPU-only hardware — verify
# this line survived the copy intact.
ss_client=AutoModelForCausalLM.from_pretrained("nchlt/omnibus-image-current-tab").to('cuda')
def load_models(choice):
    """Look up the (model, tokenizer) pair registered under *choice*.

    *choice* is one of the repo ids listed in ``models``.
    """
    model = clients[choice]
    tokenizer = tokenizers[choice]
    return model, tokenizer
def chat_inf(sys, inp, chat, mem, cli, seed, temp, tokens, top_p, rep_p, chat_mem, custom_prompt):
    """Generate one chat response with model `cli` and append it to `chat`.

    Returns (response_text, updated_chat) on success, or ([], []) when no
    user input was supplied.  NOTE(review): several lines below look broken
    as copied — see the inline flags before changing anything.
    """
    # Deterministic sampling per the UI-supplied seed.
    torch.manual_seed(int(seed))
    if not sys:
        # Fallback system prefix when the user left the field empty.
        sys = "<|startoftext|>"
    if inp is None:
        # No input: return empty chatbot content and empty state.
        return [],[]
    history=[(inp,chat)]
    chat+=[inp]
    # NOTE(review): `cli` is used as a model object (cli.generate) yet also as
    # a key into `tokenizers`, which is keyed by model *name* strings — this
    # raises KeyError unless callers actually pass the name; verify wiring.
    # NOTE(review): torch.tensor over per-turn encodings of differing lengths
    # fails for ragged sequences — presumably needs padding; confirm.
    # Note the comprehension's `inp` shadows the parameter of the same name.
    response = cli.generate(torch.tensor([tokenizers[cli].encode(f'{sys}: {inp}\n') for inp in chat]), max_length=int(tokens), temperature=float(temp), top_p=float(top_p), do_sample=True, repetition_penalty=float(rep_p))
    # Decodes only the final token column of each generated sequence —
    # TODO confirm this is intended rather than decoding the full output.
    res = tokenizers[cli].decode(response[:, -1])
    chat+=[res]
    # NOTE(review): assigning .text on a gradio component inside a handler
    # does not update the UI; values should be returned as outputs — verify.
    custom_prompt.text = "\n".join([f"{i}: {inp} <--> {res}" for i,(inp,res) in enumerate(history[::-1][:chat_mem])])
    return res, chat
def get_screenshot(cli, im_height, im_width, chatblock, theme, wait_time):
    """Presumably builds text for a chat screenshot — TODO confirm intent;
    none of the layout parameters (im_height, im_width, theme, wait_time)
    are actually used in the body as copied.
    """
    # NOTE(review): `history`, `tokens`, `temp`, `top_p` and `rep_p` are not
    # defined in this scope — NameError at call time unless module globals
    # with those names exist elsewhere; verify against the original file.
    chat = cli.generate(torch.tensor([tokenizers[cli].encode('<|startoftext|>: '+'\n'.join([inp for i,(inp,res) in enumerate(history[::-1][:chatblock]) if not i%2])) for _ in range(5)]), max_length=int(tokens), temperature=float(temp), top_p=float(top_p), do_sample=True, repetition_penalty=float(rep_p))
    # NOTE(review): `response` is undefined here — the generation result above
    # is bound to `chat`; this likely should decode `chat` instead. Verify.
    return tokenizers[cli].decode(response[:, -1])
def clear_fn():
    """Reset the input boxes, chat display and memory state to blanks.

    NOTE(review): mutating component ``.value`` inside a handler is a gradio
    anti-pattern (the UI will not refresh); kept as-is to preserve behavior.
    """
    resets = ((inp, ""), (sys_inp, ""), (chat_b, []), (memory, None))
    for component, blank in resets:
        component.value = blank
# Wire the screenshot button: passes chat state plus layout options into
# get_screenshot and renders the result into the `img` component.
# NOTE(review): im_btn, im_height, im_width, chatblock, theme, wait_time,
# img and app are never defined in this file as shown — the gr.Blocks UI
# definition appears to have been lost in this copy; verify before running.
im_go=im_btn.click(get_screenshot,[chat_b,im_height,im_width,chatblock,theme,wait_time],img)
# Allow up to 10 concurrent event workers, then start the server.
app.queue(default_concurrency_limit=10).launch()