import os
import re
from datetime import date

import gradio as gr
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM


# Hugging Face access token used to download the gated Llama-2 weights.
# It is read from the HF_TOKEN environment variable instead of being
# hard-coded: a token committed to source is exposed to everyone who can read
# the file. When the variable is unset this is None, and the Hugging Face
# libraries fall back to any locally cached login (`huggingface-cli login`).
access_token = os.environ.get("HF_TOKEN")

# Single source of truth for the base checkpoint shared by model and tokenizer.
BASE_MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"

# Load the Llama-2 chat base model. device_map="auto" lets the loader place
# layers across the available devices; any weights mapped to disk are
# offloaded into offload/.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    token=access_token,
    # NOTE(review): Llama-2 is supported natively by transformers, so this flag
    # is most likely unnecessary. It permits running code shipped in the model
    # repo; consider removing it.
    trust_remote_code=True,
    device_map="auto",
    offload_folder="offload/",
)
# Wrap the base model with the FinGPT forecaster LoRA adapter.
model = PeftModel.from_pretrained(
    base_model,
    "FinGPT/fingpt-forecaster_dow30_llama2-7b_lora",
    offload_folder="offload/",
)
# Inference only: disable training-time behaviour such as dropout.
model = model.eval()

tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL_ID,
    token=access_token,
)


def construct_prompt(ticker, date, n_weeks):
    """Combine the form inputs into a single comma-separated prompt string.

    Args:
        ticker: Stock symbol; must already be a ``str``.
        date: Date string; must already be a ``str``.
        n_weeks: Week count; converted with ``str()`` before joining.

    Returns:
        ``"<ticker>, <date>, <n_weeks>"``.
    """
    fields = (ticker, date, str(n_weeks))
    separator = ", "
    return separator.join(fields)


def get_curday():
    """Return today's local date formatted as ``YYYY-MM-DD``."""
    today = date.today()
    # date.__format__ delegates a non-empty format spec to strftime.
    return f"{today:%Y-%m-%d}"


def predict(ticker, date, n_weeks):
    """Build the text shown in the "Response" box for one form submission.

    Model inference is currently disabled: the prompt produced by
    ``construct_prompt`` is returned unchanged, so the UI echoes the inputs.

    Args:
        ticker: Stock symbol from the "Ticker" textbox.
        date: Date string from the "Date" textbox (the UI asks for
            ``yyyy-mm-dd``; it is not validated here).
        n_weeks: Value from the "n_weeks" slider.

    Returns:
        The prompt string.
    """
    # TODO: re-enable generation with the PEFT model loaded at module level.
    # The disabled pipeline, with prompt = construct_prompt(ticker, date, n_weeks):
    #     inputs = tokenizer(
    #         prompt, return_tensors='pt',
    #         padding=False, max_length=4096
    #     )
    #     inputs = {key: value.to(model.device) for key, value in inputs.items()}
    #     res = model.generate(
    #         **inputs, max_length=4096, do_sample=True,
    #         eos_token_id=tokenizer.eos_token_id,
    #         use_cache=True
    #     )
    #     output = tokenizer.decode(res[0], skip_special_tokens=True)
    #     answer = re.sub(r'.*\[/INST\]\s*', '', output, flags=re.DOTALL)
    # NOTE(review): the tokenizer ignores max_length unless truncation=True.
    return construct_prompt(ticker, date, n_weeks)


# Web UI: three inputs (ticker, start date, look-back window) feed `predict`,
# whose string result is shown in a single "Response" textbox.
demo = gr.Interface(
    predict,
    inputs=[
        gr.Textbox(
            label="Ticker",
            value="AAPL",
            info="Companies from Dow-30 are recommended",
        ),
        gr.Textbox(
            label="Date",
            # Pass the function itself, not its result: Gradio calls it each
            # time the app loads, so the default stays current for a long-running app.
            value=get_curday,
            info="Date from which the prediction is made, use format yyyy-mm-dd",
        ),
        gr.Slider(
            minimum=1,
            maximum=4,
            value=3,
            step=1,
            label="n_weeks",
            info="Information of the past n weeks will be utilized, choose between 1 and 4",
        ),
    ],
    outputs=[
        gr.Textbox(
            label="Response",
        ),
    ],
)

# Start the server only when run as a script, so importing this module
# (e.g. by Gradio's reload mode or a test) does not launch it.
if __name__ == "__main__":
    demo.launch()