import gradio as gr
import base64
import random
import gradio as gr
#import urllib.request
import requests
import bs4
import lxml
import os
#import subprocess
from huggingface_hub import InferenceClient,HfApi
import random
import json
import datetime
from pypdf import PdfReader
import uuid
#from query import tasks
from gradio_client import Client

from agent import (
    PREFIX,
    GET_CHART,
    COMPRESS_DATA_PROMPT,
    COMPRESS_DATA_PROMPT_SMALL,
    LOG_PROMPT,
    LOG_RESPONSE,
)
# Hugging Face Hub API handle. NOTE(review): not referenced in the visible
# code of this file; confirm before removing.
api=HfApi()

# Shared text-generation client used by run_gpt / run_gpt_no_prefix.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")


def sort_fn(inp):
    """Send *inp* (stringified) to the "Omnibus/sort_document" Space.

    The Space's ``/sort_doc`` endpoint returns a pair; only the second
    element (the nouns output) is returned, the first is discarded.
    """
    sorter = Client("Omnibus/sort_document")
    _sentences, nouns = sorter.predict(f"{inp}", api_name="/sort_doc")
    return nouns

def find_all(url, timeout=30):
    """Fetch *url* and return the page's visible text.

    Args:
        url: Page to fetch. An empty string or None is rejected up front.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled server would hang the request indefinitely.

    Returns:
        tuple[bool, str]:
            (True, "RAW TEXT RETURNED: <text>") on success,
            (False, "Enter Valid URL") for an empty/None url,
            (False, "Error: <exception>") if fetching or parsing fails.
    """
    print (url)
    print (f"trying URL:: {url}")
    if not url:
        return False, "Enter Valid URL"
    try:
        source = requests.get(url, timeout=timeout)
        soup = bs4.BeautifulSoup(source.content, 'lxml')
        rawp = f'RAW TEXT RETURNED: {soup.text}'
    except Exception as e:
        # Best-effort fetch: report the failure to the caller instead of raising.
        print (e)
        return False, f'Error: {e}'
    print(rawp)
    return True, rawp

# Prompt template for find_keyword_fn. "{keywords}" is filled via str.format.
# The template ends at the "user query:" label, so any query text has to be
# appended after it for the model to see it.
FIND_KEYWORDS="""Find keywords from the dictionary of provided keywords that are relevant to the users query.
Return the keyword:value pairs from the list in the form of a JSON file output.
dictionary:
{keywords}
user query:
"""

def find_keyword_fn(c,inp,data):
    """Ask the model which entries of *data* are relevant to the query *inp*.

    ``str(data)`` is cut into consecutive 20000-character slices. Each slice
    is sent to ``run_gpt`` with the FIND_KEYWORDS prompt, followed by the
    user's query.

    Args:
        c: Size hint kept for backward compatibility; it is only logged.
            Slicing is driven by the real length of ``str(data)``, so no part
            of the data is skipped and ``c == 0`` cannot divide by zero.
        inp: The user's query, appended after the prompt's "user query:" label.
        data: Keyword dictionary (or any object) to search; stringified.

    Returns:
        list[str]: one model response per slice, in order.
    """
    chunk_size = 20000
    data = str(data)
    seed = random.randint(1, 1000000000)
    print(f'c:: {c}')
    out = []
    for s in range(0, len(data), chunk_size):
        e = s + chunk_size
        print(f's:e :: {s}:{e}')
        resp = run_gpt(
            # FIND_KEYWORDS stops at "user query:"; supply the query itself.
            FIND_KEYWORDS + "{query}",
            stop_tokens=[],
            max_tokens=8192,
            seed=seed,
            keywords=data[s:e],  # only this slice, not the whole dictionary
            query=inp,
        ).strip("\n")
        out.append(resp)
        print (resp)
    return out
    

def read_txt(txt_path):
    """Return (and log) the full contents of the text file at *txt_path*.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the host's locale default encoding. Decoding errors propagate to the
    caller.
    """
    with open(txt_path, "r", encoding="utf-8") as f:
        text = f.read()
    print (text)
    return text

def read_pdf(pdf_path):
    """Extract the text of every page of the PDF at *pdf_path*.

    Each page's text is preceded by a newline, so the result starts with
    "\\n". The combined text is printed before being returned.
    """
    pages = PdfReader(f'{pdf_path}').pages
    text = "".join(f'\n{page.extract_text()}' for page in pages)
    print (text)
    return text

# URLs whose PDF download returned a non-200 status (appended by read_pdf_online).
error_box=[]
def read_pdf_online(url, timeout=60):
    """Download the PDF at *url* and return its extracted text.

    The PDF is parsed from memory instead of a fixed on-disk "test.pdf", so
    concurrent runs cannot overwrite each other's download.

    Args:
        url: Address of the PDF.
        timeout: Seconds to wait for the server (requests has no default).

    Returns:
        str: page texts, each preceded by "\\n", when the server answers 200.
        int: the HTTP status code for any other status; the URL is also
            appended to ``error_box``.
        Exception: the exception raised while downloading or parsing. This
            now includes network errors, which previously escaped the function.
    """
    import io

    print(f"reading {url}")
    try:
        response = requests.get(url, timeout=timeout)
        print(response.status_code)
        if response.status_code != 200:
            error_box.append(url)
            print(response.status_code)
            return response.status_code
        reader = PdfReader(io.BytesIO(response.content))
        print(len(reader.pages))
        text = "".join(f'\n{page.extract_text()}' for page in reader.pages)
        # Log the text once rather than re-printing the growing buffer per page.
        print(f"PDF_TEXT:: {text}")
        return text
    except Exception as e:
        print (e)
        return e


# When True, run_gpt prints the model response (run_gpt_no_prefix's check
# is commented out, so it always prints).
VERBOSE = True
# NOTE(review): not referenced in the visible code; confirm before removing.
MAX_HISTORY = 100
# Size of each slice of input text sent to the model, in characters
# (the compress_data* functions slice a string with it).
MAX_DATA = 20000

def format_prompt(message, history):
    """Render *history* and *message* in the Mixtral ``[INST]`` chat template.

    *history* is an iterable of ``(user_prompt, bot_response)`` pairs. Each
    pair becomes ``[INST] user [/INST] bot</s> `` and the new *message* is
    appended as a final open instruction.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"

def run_gpt_no_prefix(
    prompt_template,
    stop_tokens,
    max_tokens,
    seed,
    **prompt_kwargs,
):
    """Fill *prompt_template* and stream a completion from ``client``.

    Unlike ``run_gpt``, no PREFIX is prepended. ``stop_tokens`` is accepted
    for signature parity but is not forwarded to the model. Any exception
    (formatting or generation) is logged and the literal string "Error" is
    returned instead of raising.
    """
    print(seed)
    try:
        prompt = prompt_template.format(**prompt_kwargs)
        print(LOG_PROMPT.format(prompt))

        token_stream = client.text_generation(
            prompt,
            temperature=0.9,
            max_new_tokens=max_tokens,
            top_p=0.95,
            repetition_penalty=1.0,
            do_sample=True,
            seed=seed,
            stream=True,
            details=True,
            return_full_text=False,
        )
        # Concatenate the streamed tokens into the full completion.
        completion = "".join(event.token.text for event in token_stream)

        print(LOG_RESPONSE.format(completion))
        return completion
    except Exception as e:
        print(f'no_prefix_error:: {e}')
        return "Error"
def run_gpt(
    prompt_template,
    stop_tokens,
    max_tokens,
    seed,
    **prompt_kwargs,
):
    """Stream a completion for PREFIX + the formatted *prompt_template*.

    PREFIX is filled with the current local time and a fixed purpose string.
    ``stop_tokens`` is accepted but not forwarded to the model. Exceptions
    from formatting or generation propagate to the caller. The response is
    printed only when VERBOSE is set.
    """
    print(seed)
    now = datetime.datetime.now()

    header = PREFIX.format(
        timestamp=now,
        purpose="Compile the provided data and complete the users task"
    )
    content = header + prompt_template.format(**prompt_kwargs)
    print(LOG_PROMPT.format(content))

    token_stream = client.text_generation(
        content,
        temperature=0.9,
        max_new_tokens=max_tokens,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=seed,
        stream=True,
        details=True,
        return_full_text=False,
    )
    completion = "".join(event.token.text for event in token_stream)

    if VERBOSE:
        print(LOG_RESPONSE.format(completion))
    return completion

    
def compress_data(c, instruct, history):
    """Compress *history* slice by slice with COMPRESS_DATA_PROMPT_SMALL.

    *history* (the text to compress) is cut into consecutive MAX_DATA-character
    slices. Each slice is sent to ``run_gpt`` independently, together with the
    user's instruction, and the responses are collected in order.

    Args:
        c: Caller-supplied size hint (summarize() passes a rough word count).
            It is only logged. Slicing is driven by ``len(history)`` because
            slices are measured in characters; deriving the slice count from
            a word count left most of a long text unprocessed.
        instruct: User instruction, passed as the prompt's ``direction``.
        history: The text to compress.

    Returns:
        list[str]: one model response per slice (empty for empty *history*).
    """
    seed = random.randint(1, 1000000000)
    print (c)
    print(f'chunk:: {MAX_DATA}')
    print(f'divi:: {(len(history) + MAX_DATA - 1) // MAX_DATA}')
    out = []
    for s in range(0, len(history), MAX_DATA):
        e = s + MAX_DATA
        print(f's:e :: {s}:{e}')
        # NOTE: run_gpt accepts stop_tokens but does not forward them.
        resp = run_gpt(
            COMPRESS_DATA_PROMPT_SMALL,
            stop_tokens=["observation:", "task:", "action:", "thought:"],
            max_tokens=8192,
            seed=seed,
            direction=instruct,
            knowledge="",
            history=history[s:e],
        ).strip("\n")
        out.append(resp)
        print (resp)
    return out

    
def compress_data_og(c, instruct, history):
    """Rolling compression of *history* with COMPRESS_DATA_PROMPT.

    *history* is processed in consecutive MAX_DATA-character slices. Each call
    receives the previous call's response as ``knowledge``, so the summary
    accumulates across slices.

    Args:
        c: Caller-supplied size hint; only logged. Slicing follows
            ``len(history)`` so the whole text is covered.
        instruct: User instruction, passed as the prompt's ``direction``.
        history: The text to compress.

    Returns:
        str: the final (last) response, or "" when *history* is empty.
    """
    seed = random.randint(1, 1000000000)
    print (c)
    print(f'chunk:: {MAX_DATA}')
    print(f'divi:: {(len(history) + MAX_DATA - 1) // MAX_DATA}')
    resp = ""  # defined even when there are no slices
    new_history = ""
    for s in range(0, len(history), MAX_DATA):
        e = s + MAX_DATA
        print(f's:e :: {s}:{e}')
        resp = run_gpt(
            COMPRESS_DATA_PROMPT,
            stop_tokens=["observation:", "task:", "action:", "thought:"],
            max_tokens=8192,
            seed=seed,
            direction=instruct,
            knowledge=new_history,
            history=history[s:e],
        ).strip("\n")
        # Feed this summary into the next slice's prompt.
        new_history = resp
        print (resp)
    print ("final" + resp)
    return resp

def get_chart(inp):
    """Ask the model (GET_CHART prompt, no PREFIX) for a chart of *inp*.

    Always returns a string: the stripped model output, or "Error:: <exc>" if
    the call raises. The caller splits the result into lines, so returning an
    Exception object here would crash it.
    """
    seed = random.randint(1, 1000000000)
    try:
        resp = run_gpt_no_prefix(
            GET_CHART,
            stop_tokens=[],
            max_tokens=8192,
            seed=seed,
            inp=inp,
        ).strip("\n")
        print(resp)
    except Exception as e:
        print(f'Error:: {e}')
        resp = f'Error:: {e}'
    return resp

def format_json(inp):
    """Join model response(s) in *inp* and parse them as JSON / Python literal.

    Args:
        inp: A list of response strings (as returned by compress_data) or a
            single response string. Each response is stripped of surrounding
            whitespace, code-fence backticks, '#' and '/' before the pieces
            are concatenated, and the joined text is trimmed of '<', '/', 's'
            and '>' characters at its ends (left over from "</s>").

    Returns:
        The parsed object (usually a dict or list).

    Raises:
        ValueError: if the text cannot be parsed. The model output is never
            passed to ``eval``; it can be steered by fetched web/PDF content,
            so only data literals are accepted.
    """
    print("FORMATTING:::")
    print(type(inp))
    print("###########")
    print(inp)
    print("###########")
    print("###########")
    # A bare string must be treated as one response, not iterated per character.
    parts = [inp] if isinstance(inp, str) else inp
    new_str = ""
    for part in parts:
        part = part.strip()
        print(part)
        new_str += part.strip("```").strip("#").strip("//")
    print("###########")
    print("###########")
    new_str = new_str.strip("</s>")
    out_json = _parse_literal(new_str)
    print(out_json)
    print("###########")
    print("###########")

    return out_json


def _parse_literal(text):
    """Parse *text* as JSON, else as a Python literal; retry on the outermost {...}/[...] span."""
    import ast

    candidates = [text]
    # Fall back to the outermost bracketed span, e.g. to skip a "json" fence tag.
    starts = [i for i in (text.find("{"), text.find("[")) if i != -1]
    end = max(text.rfind("}"), text.rfind("]"))
    if starts and end > min(starts):
        candidates.append(text[min(starts):end + 1])
    last_err = None
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except ValueError as err:
            last_err = err
        try:
            return ast.literal_eval(candidate)
        except (ValueError, SyntaxError, TypeError) as err:
            last_err = err
    raise ValueError(f"could not parse model output as JSON: {last_err}")


# Single-element list holding the last zoom value written by zoom_update().
# NOTE(review): the zoom_btn.change wiring is commented out in the UI, so this
# appears to be unused at runtime; confirm.
this=["1.25"]
# Styles for the #wrap/#frame chart container. NOTE(review): this string is not
# passed to gr.Blocks() and "$ZOOM" is never substituted in this file; mm()
# inlines equivalent styles instead.
css="""
#wrap { width: 100%; height: 100%; padding: 0; overflow: auto; }
#frame { width: 100%; border: 1px solid black; }
#frame { zoom: $ZOOM; -moz-transform: scale($ZOOM); -moz-transform-origin: 0 0; }
"""


def mm(graph,zoom):
    """Return HTML that renders Mermaid source *graph* in an iframe at *zoom*.

    Lines of *graph* are stripped and concatenated without a separator, so
    statements must already be ';'-terminated (summarize() ensures this).
    The code is percent-encoded before it goes into the iframe URL. A raw
    '#', '&' or '%' would truncate or corrupt the query parameter, and a raw
    '"' would close the src attribute and inject markup.
    """
    from urllib.parse import quote

    code_out = "".join(line.strip() for line in graph.split("\n"))
    # Keep URI punctuation that decodeURI() leaves alone readable; escape the rest
    # (spaces, '#', '&', '%', quotes, angle/square brackets, ...).
    encoded = quote(code_out, safe="/:;,=?@+$!*'()")
    out_html=f'''<div id="wrap" style="width: 100%; height: 100%;max-height:600px; padding: 0; overflow: auto;"><iframe id="frame" src="https://omnibus-mermaid-script.static.hf.space/index.html?mermaid={encoded}" style="border: 1px solid black; zoom: {str(zoom)}; -moz-transform: scale({str(zoom)}); -moz-transform-origin: 0 0;"></iframe></div>'''
    return out_html
    
def _gather_data(data, files, directory, url, pdf_url, pdf_batch):
    """Read the populated input sources; return ``(data, error_message)``.

    Precedence follows the original tab order:
    - A PDF batch replaces pasted text.
    - A single PDF URL replaces that.
    - A web-page URL replaces that.
    - Uploaded files are appended to whatever was collected.

    ``data`` is always a str. "Error" or "" means nothing usable was loaded,
    and ``error_message`` then says why.
    """
    error_message = "Provide a valid data source"
    data = "" if data is None else data

    if pdf_batch and pdf_batch.startswith("http"):
        data = ""
        for batch_url in pdf_batch.split(","):
            batch_url = batch_url.strip()
            if not batch_url:
                continue  # tolerate trailing / doubled commas
            try:
                bb = read_pdf_online(batch_url)
            except Exception as e:
                # One bad URL must not discard the rest of the batch.
                print(e)
                bb = f"Error reading URL ({batch_url}): {e}"
            data = f'{data}\nFile Name URL ({batch_url}):\n{bb}'

    if directory:
        for ea in directory:
            print(ea)

    if pdf_url and pdf_url.startswith("http"):
        print("PDF_URL")
        try:
            out = read_pdf_online(pdf_url)
        except Exception as e:
            out = e
        if isinstance(out, str):
            data = out
        else:
            # A status code or exception is not document text; don't chart it.
            data = "Error"
            error_message = f"Error reading PDF URL ({pdf_url}): {out}"

    if url and url.startswith("http"):
        ok, out = find_all(url)
        if ok:
            data = out
        else:
            data = "Error"
            error_message = str(out)

    if files:
        for file in files:
            try:
                print (file)
                if file.endswith(".pdf"):
                    zz = read_pdf(file)
                elif file.endswith(".txt"):
                    zz = read_txt(file)
                else:
                    continue
                print (zz)
                data = f'{data}\nFile Name ({file}):\n{zz}'
            except Exception as e:
                data = f'{data}\nError opening File Name ({file})'
                print (e)
    return data, error_message


def _extract_mermaid(chart_text):
    """Return the first fenced block of *chart_text* as ';'-terminated lines.

    Double quotes are removed and blank lines skipped. A missing closing
    fence reads to the end of the text, and no fence at all yields "".
    """
    lines = chart_text.split("\n")
    for start, line in enumerate(lines):
        if "```" in line:
            break
    else:
        return ""
    body = []
    for line in lines[start + 1:]:
        if "```" in line:
            break
        line = line.strip().replace('"', "")
        if not line:
            continue
        if not line.endswith(";"):
            line += ";"
        body.append(line + "\n")
    return "".join(body)


def summarize(inp,history,data=None,files=None,directory=None,url=None,pdf_url=None,pdf_batch=None):
    """Gradio generator: load data, compress it to JSON, then chart it.

    Steps:
    1. Load the selected data source(s).
    2. Compress the text to JSON with the model.
    3. Ask the model for a Mermaid chart and render it via mm().

    Yields ``(prompt_text, chat_history, chart_html, chart_code, json_out)``
    tuples matching the outputs wired to the run button. The incoming
    ``history`` is not read; every run starts a fresh conversation. If no
    usable data is loaded, the reason is shown in the chat and empty chart
    outputs are yielded.
    """
    json_box = []
    chart_out = ""
    if not inp:
        inp = "Process this data"
    yield "", [(inp, "Working on it...")], chart_out, chart_out, json_box

    data, error_message = _gather_data(data, files, directory, url, pdf_url, pdf_batch)
    if data == "Error" or data == "":
        yield "", [(inp, error_message)], "", "", json_box
        return

    yield "", [(inp, "Data: Loaded, processing...")], chart_out, chart_out, json_box

    print(inp)
    out = str(data)
    print(f'rl:: {len(out)}')
    # Rough word count (separators + 1), passed to compress_data as its size hint.
    c = 1 + sum(out.count(sep) for sep in (" ", ",", "\n"))
    print (f'c:: {c}')
    json_out = compress_data(c, inp, out)
    try:
        json_out = format_json(json_out)
    except Exception as e:
        print (e)  # fall back to the raw list of model responses
    yield "", [(inp, json_out), (None, "Building Chart...")], chart_out, chart_out, json_box

    chart_out = str(get_chart(str(json_out)))
    line_out = _extract_mermaid(chart_out)
    chart_html = mm(line_out, 1)
    yield "", [(inp, chart_out)], chart_html, line_out, json_out

#################################
def clear_fn():
    """Reset the instruction box to "" and the chat to one empty exchange."""
    empty_chat = [(None, None)]
    return "", empty_chat



def zoom_update(inp):
    """Replace the contents of the module-level ``this`` list with ``str(inp)``.

    The list is mutated in place (not rebound). Returns a no-op gr.update().
    """
    del this[:]
    this.extend([str(inp)])
    return gr.update()

# Gradio UI. It is built and launched at import time (see the last line).
with gr.Blocks() as app:
    gr.HTML("""<center><h1>Text -to- Chart</h1><h3>Mixtral 8x7B</h3>""")
    chatbot = gr.Chatbot(label="Mixtral 8x7B Chatbot",show_copy_button=True)
    with gr.Row():
        with gr.Column(scale=3):
            prompt=gr.Textbox(label = "Instructions (optional)")
        with gr.Column(scale=1):
            
            # Runs summarize (wired below as `go`).
            button=gr.Button()
        
        #models_dd=gr.Dropdown(choices=[m for m in return_list],interactive=True)
    with gr.Row():
        stop_button=gr.Button("Stop")
        clear_btn = gr.Button("Clear")
    # One tab per input source; summarize() receives all of them as inputs.
    with gr.Row():
        with gr.Tab("Text"):
            data=gr.Textbox(label="Input Data (paste text)", lines=6)
        with gr.Tab("File"):
            file=gr.Files(label="Input File(s) (.pdf .txt)")
        with gr.Tab("Folder"):
            directory=gr.File(label="Folder", file_count='directory')            
        with gr.Tab("Raw HTML"):
            url = gr.Textbox(label="URL")
        with gr.Tab("PDF URL"):
            pdf_url = gr.Textbox(label="PDF URL")       
        with gr.Tab("PDF Batch"):
            pdf_batch = gr.Textbox(label="PDF URL Batch (comma separated)")
    # Rendered chart HTML (output of mm()).
    m_box=gr.HTML()
    # Zoom factor for the chart. It only takes effect when "Update Chart" is
    # clicked, because the zoom_btn.change handler below is commented out.
    zoom_btn=gr.Slider(label="Zoom",step=0.01,minimum=0.1,maximum=20,value=1,interactive=True)
    # Editable Mermaid source; "Update Chart" re-renders it via mm().
    e_box=gr.Textbox(interactive=True)
    upd_button=gr.Button("Update Chart")
    # Parsed JSON produced from the compressed data.
    json_out=gr.JSON()
    #text=gr.JSON()

    #zoom_btn.change(zoom_update,zoom_btn,None)
    upd_button.click(mm,[e_box,zoom_btn],[m_box])
    #inp_query.change(search_models,inp_query,models_dd)
    clear_btn.click(clear_fn,None,[prompt,chatbot])
    
    #go=button.click(summarize,[prompt,chatbot,report_check,chart_check,data,file,directory,url,pdf_url,pdf_batch],[prompt,chatbot,e_box,json_out])
    # summarize yields 5-tuples that map positionally onto these five outputs.
    go=button.click(summarize,[prompt,chatbot,data,file,directory,url,pdf_url,pdf_batch],[prompt,chatbot,m_box,e_box,json_out])
    
    # "Stop" cancels an in-flight summarize run.
    stop_button.click(None,None,None,cancels=[go])
# Enable the Gradio queue (default_concurrency_limit=20 caps concurrent runs
# per event) and hide the API page.
app.queue(default_concurrency_limit=20).launch(show_api=False)