File size: 1,320 Bytes
1e156c7
52631c2
9947a95
 
fb4cef0
6fe5c25
9947a95
41ca296
9947a95
 
c867b0d
9947a95
 
 
 
fb4cef0
 
9947a95
 
 
 
 
 
 
fb4cef0
 
9947a95
8f714ea
52631c2
9947a95
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import json
import logging
from typing import List, Optional

import gradio as gr

from Classes.Owiki_Class import OWiki
from Classes.run_local_LLM import LocalLLM

# Returned when the primary OWiki pipeline failed and the local-LLM fallback
# produced an empty/falsy answer.
_FALLBACK_MESSAGE = (
    "Due to limited compute, I am unable to answer at this moment. "
    "Please upgrade your deployment space."
)

# Invocation types this endpoint knows how to route.
_SUPPORTED_INVOCATION_TYPES = ("SQL", "OIC")

logger = logging.getLogger(__name__)


def _load_hyperparameters(path: str = "config.json") -> dict:
    """Read the JSON config whose keys are passed as kwargs to OWiki/LocalLLM."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _format_chat(chat_history) -> str:
    """Render each (user, bot) turn as ``"User: ... Bot: ...\\n\\n"`` and concatenate."""
    return "".join(f"User: {user} Bot: {bot}\n\n" for user, bot in chat_history)


def _local_fallback(query: str, invocation_type: str, schemas: dict, hyperparameters: dict) -> str:
    """Answer with LocalLLM; return the fixed fallback message if it yields nothing."""
    # Built only when needed: constructing the local model may be costly
    # (NOTE(review): assumption based on usage, confirm in LocalLLM.__init__).
    local_llm = LocalLLM(**hyperparameters)
    res = local_llm.predict(query, invocation_type, schemas)
    return res or _FALLBACK_MESSAGE


def predict(query: str, chat_history: List[tuple[str, str]], invocation_type: str = "OIC", schemas: Optional[dict] = None) -> str:
    """Answer *query* with the OWiki pipeline selected by *invocation_type*.

    Args:
        query: The user's question.
        chat_history: Prior ``(user, bot)`` turns; used only for ``"OIC"``
            requests. ``None`` is treated as an empty history.
        invocation_type: ``"SQL"`` routes to ``OWiki.create_sql_agent``;
            ``"OIC"`` routes to ``OWiki.search_from_db``.
        schemas: Forwarded to the SQL agent and to the local fallback.
            Defaults to a fresh empty dict on each call.

    Returns:
        The primary pipeline's answer. If that raises, the LocalLLM answer,
        or a fixed "limited compute" message when LocalLLM returns nothing.

    Raises:
        ValueError: If *invocation_type* is not ``"SQL"`` or ``"OIC"``.
        OSError / json.JSONDecodeError: If ``config.json`` cannot be read.
    """
    # Validate before any I/O; previously an unknown type reached
    # ``return res`` with ``res`` unbound (UnboundLocalError).
    if invocation_type not in _SUPPORTED_INVOCATION_TYPES:
        raise ValueError(
            f"Unsupported invocation_type {invocation_type!r}; "
            f"expected one of {_SUPPORTED_INVOCATION_TYPES}"
        )
    # Avoid a shared mutable default; also normalizes an empty JSON input.
    schemas = {} if schemas is None else schemas

    hyperparameters = _load_hyperparameters()
    wiki = OWiki(**hyperparameters)
    try:
        if invocation_type == "SQL":
            return wiki.create_sql_agent(query, schemas)
        return wiki.search_from_db(query, _format_chat(chat_history or []))
    except Exception:
        # Deliberate best-effort: any primary-pipeline failure falls back to
        # the local model, but the cause is now logged instead of dropped.
        logger.exception("Primary %s pipeline failed; using local LLM fallback", invocation_type)
        return _local_fallback(query, invocation_type, schemas, hyperparameters)


# Gradio UI: inputs map positionally to predict(query, chat_history,
# invocation_type, schemas); the answer is shown as text.
iface = gr.Interface(
    fn=predict,
    inputs=["text", "list", "text", "json"],
    outputs="text",
)

# Launch only when run as a script, so importing this module (e.g. for tests
# or reuse of `predict`) does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch(debug=True)