# Source metadata (from a repository blame view): file size 1,877 bytes;
# commits 17e8c28, 8e07109, 8acc27a, 9fbab2d.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import random

# Chat models offered by the app, keyed by Hugging Face model id; each gets a
# tokenizer and a CPU-resident model loaded once at import time.
models = ["microsoft/DialoGPT-medium", "facebook/opt-125m"]
tokenizers = {name: AutoTokenizer.from_pretrained(name) for name in models}
clients = {name: AutoModelForCausalLM.from_pretrained(name).to('cpu') for name in models}
# Use the GPU when one is present, otherwise stay on CPU: an unconditional
# .to('cuda') raises at import time on CPU-only hosts.
# NOTE(review): ss_client is not referenced by any code visible in this file.
ss_client = AutoModelForCausalLM.from_pretrained("nchlt/omnibus-image-current-tab").to(
    'cuda' if torch.cuda.is_available() else 'cpu'
)

def load_models(choice):
    """Return the preloaded ``(model, tokenizer)`` pair registered under ``choice``.

    ``choice`` is a key of both ``clients`` and ``tokenizers``.
    """
    model = clients[choice]
    tokenizer = tokenizers[choice]
    return model, tokenizer

def _lookup_client(cli):
    """Return the ``(model, tokenizer)`` pair for ``cli``.

    ``cli`` may be a model name (a key of ``clients``/``tokenizers``) or one of
    the model objects stored in ``clients``. ``tokenizers`` is keyed by name
    only, so a model object is mapped back to its name by identity.

    Raises:
        ValueError: if ``cli`` is neither a known name nor a known model.
    """
    if isinstance(cli, str):
        return clients[cli], tokenizers[cli]
    for name, model in clients.items():
        if model is cli:
            return model, tokenizers[name]
    raise ValueError(f"unknown chat client: {cli!r}")


def chat_inf(sys, inp, chat, mem, cli, seed, temp, tokens, top_p, rep_p, chat_mem, custom_prompt):
    """Generate one bot reply and append the exchange to the chat log.

    Args:
        sys: Speaker prefix written before every turn in the prompt; falls
            back to ``"<|startoftext|>"`` when empty.
        inp: The new user message. ``None`` short-circuits to ``([], [])``.
        chat: Flat list of earlier turns (user, bot, user, bot, ...). It is
            not modified in place; an extended copy is returned.
        mem: Unused.
        cli: Model name (key of ``clients``) or a model object from ``clients``.
        seed: Seed for ``torch.manual_seed`` so sampling is reproducible.
        temp: Sampling temperature passed to ``generate``.
        tokens: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling ``top_p`` passed to ``generate``.
        rep_p: Repetition penalty passed to ``generate``.
        chat_mem: How many of the most recent (user, bot) exchanges to list in
            the memory summary; ``None`` lists all of them.
        custom_prompt: Object whose ``text`` attribute receives the summary.

    Returns:
        ``(reply, chat)``: the decoded reply text and the chat log extended
        with ``inp`` and the reply.
    """
    torch.manual_seed(int(seed))
    if not sys:
        sys = "<|startoftext|>"
    if inp is None:
        return [], []
    model, tokenizer = _lookup_client(cli)

    chat = list(chat or []) + [inp]
    # Encode the whole conversation as ONE prompt: encoding each turn
    # separately produced rows of different lengths, which torch.tensor
    # cannot stack into a batch.
    prompt = "".join(f"{sys}: {turn}\n" for turn in chat)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output = model.generate(
        input_ids,
        # max_new_tokens rather than max_length: max_length counts the prompt,
        # so the reply budget shrank as the history grew and vanished once the
        # prompt alone exceeded it.
        max_new_tokens=int(tokens),
        temperature=float(temp),
        top_p=float(top_p),
        do_sample=True,
        repetition_penalty=float(rep_p),
    )
    # Decode only the generated continuation (not the echoed prompt, and not
    # just the final token as before).
    res = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)
    chat.append(res)

    # Pair the flat log into (user, bot) exchanges and list the newest first.
    history = list(zip(chat[0::2], chat[1::2]))
    limit = None if chat_mem is None else int(chat_mem)
    # NOTE(review): if custom_prompt is a plain component *value* (e.g. a str),
    # this attribute assignment raises AttributeError; confirm what is passed.
    custom_prompt.text = "\n".join(
        f"{i}: {user} <--> {bot}"
        for i, (user, bot) in enumerate(history[::-1][:limit])
    )
    return res, chat

def get_screenshot(cli, im_height, im_width, chatblock, theme, wait_time,
                   history=None, tokens=None, temp=None, top_p=None, rep_p=None):
    """Sample five continuations of recent user turns and return them as text.

    NOTE(review): despite the name and the image/theme parameters, this body
    only runs text generation; ``im_height``, ``im_width``, ``theme`` and
    ``wait_time`` are unused. Confirm the intended behaviour.

    Args:
        cli: Model name (key of ``clients``) or a model object from ``clients``.
            NOTE(review): the ``im_btn.click`` wiring passes the chatbot value
            in this position, which the lookup below rejects.
        chatblock: Number of most recent ``(user, bot)`` exchanges to consider;
            ``None`` considers all of them.
        history: ``(user, bot)`` exchanges, oldest first. Previously an
            undefined global (NameError); defaults to no history.
        tokens: Optional maximum number of new tokens.
        temp: Optional sampling temperature.
        top_p: Optional nucleus-sampling ``top_p``.
        rep_p: Optional repetition penalty.
            These four were previously undefined globals; ``None`` leaves the
            model's own generation default in place.

    Returns:
        The five sampled continuations, joined with newlines.

    Raises:
        ValueError: if ``cli`` is neither a known model name nor model object.
    """
    if isinstance(cli, str):
        model, tokenizer = clients[cli], tokenizers[cli]
    else:
        # `tokenizers` is keyed by model name, so map a model object back to
        # its name by identity instead of using the object as a key.
        name = next((n for n, m in clients.items() if m is cli), None)
        if name is None:
            raise ValueError(f"unknown chat client: {cli!r}")
        model, tokenizer = cli, tokenizers[name]

    limit = None if chatblock is None else int(chatblock)
    # Same selection as the original expression: walk the exchanges
    # newest-first, keep the first `chatblock`, take the user side of every
    # other one.
    user_turns = [
        user
        for i, (user, _bot) in enumerate(list(history or [])[::-1][:limit])
        if i % 2 == 0
    ]
    prompt = '<|startoftext|>: ' + '\n'.join(user_turns)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Only forward the settings the caller actually supplied.
    gen_kwargs = {"do_sample": True}
    for key, value, cast in (
        ("max_new_tokens", tokens, int),
        ("temperature", temp, float),
        ("top_p", top_p, float),
        ("repetition_penalty", rep_p, float),
    ):
        if value is not None:
            gen_kwargs[key] = cast(value)

    # Five rows of the same prompt -> five independent samples.
    output = model.generate(input_ids.repeat(5, 1), **gen_kwargs)
    # Decode each row's generated continuation; the old code decoded an
    # undefined `response` and only the last token of each row.
    samples = tokenizer.batch_decode(output[:, input_ids.shape[-1]:], skip_special_tokens=True)
    return "\n".join(samples)

def clear_fn():
    """Return cleared values for the input box, system prompt, chat log and memory.

    Gradio writes an event handler's return values into the components listed
    as that event's outputs. The previous version assigned ``component.value``
    on module globals instead. That does not update the page a user already
    has open, and it alters server-side component state shared by every
    session.

    Returns:
        ``("", "", [], None)``, in the order input text, system prompt, chat
        log, memory. NOTE(review): the event wiring is not visible here;
        confirm its outputs are ``[inp, sys_inp, chat_b, memory]`` in this
        order.
    """
    return "", "", [], None
    
# Wire the image button. Gradio passes the current values of the input
# components positionally to get_screenshot and shows its return value in img.
# NOTE(review): the first input (chat_b) lands in get_screenshot's `cli`
# parameter; confirm that is intended.
im_go = im_btn.click(
    get_screenshot,
    inputs=[chat_b, im_height, im_width, chatblock, theme, wait_time],
    outputs=img,
)
# Enable queuing with a default concurrency limit of 10 per event listener,
# then start the server.
app.queue(default_concurrency_limit=10).launch()