from __future__ import annotations
from typing import Iterable

import gradio as Gradio
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes

from gpt4all import GPT4All
# Local GGUF model loaded through gpt4all. This single module-level instance
# is what run_falcon() generates with for every request.
MODEL_FILE = "mistral-7b-instruct-v0.1.Q4_0.gguf"
model = GPT4All(MODEL_FILE)

# NOTE(review): `theme` is never passed to Gradio.Blocks (the app is built
# with custom_theme instead), so this Monochrome theme looks unused. Confirm
# nothing else imports it before deleting.
_inter_font_stack = [
	Gradio.themes.GoogleFont("Inter"),
	"ui-sans-serif",
	"system-ui",
	"sans-serif",
]
theme = Gradio.themes.Monochrome(
	font=_inter_font_stack,
	radius_size=Gradio.themes.sizes.radius_sm,
	neutral_hue="neutral",
	secondary_hue="purple",
	primary_hue="purple",
)

class PurpleTheme(Base):
	"""Purple Gradio theme used for the CogniForge UI.

	Extends ``gradio.themes.Base`` with these defaults:

	- purple primary and secondary hues and a neutral grey scale;
	- Inter as the UI font and Space Grotesk as the monospace font.

	It then overrides a handful of CSS variables: gradient primary buttons,
	drop shadows on blocks, buttons and inputs, and purple input borders.
	Values written as ``*name`` refer to other theme variables.
	"""

	def __init__(
		self,
		*,
		primary_hue: colors.Color | str = colors.purple,
		secondary_hue: colors.Color | str = colors.purple,
		neutral_hue: colors.Color | str = colors.neutral,
		spacing_size: sizes.Size | str = sizes.spacing_md,
		radius_size: sizes.Size | str = sizes.radius_md,
		font: fonts.Font
		| str
		| Iterable[fonts.Font | str] = (
			fonts.GoogleFont("Inter"),
			"ui-sans-serif",
			"sans-serif",
		),
		font_mono: fonts.Font
		| str
		| Iterable[fonts.Font | str] = (
			fonts.GoogleFont("Space Grotesk"),
			"ui-monospace",
			"monospace",
		),
	):
		"""Build the theme; every argument is forwarded to ``Base.__init__``.

		Args:
			primary_hue: Main accent colour (default purple).
			secondary_hue: Second accent colour (default purple).
			neutral_hue: Colour family used for greys (default neutral).
			spacing_size: Spacing scale (default ``sizes.spacing_md``).
			radius_size: Corner-radius scale (default ``sizes.radius_md``).
			font: Body font stack (default Inter, then generic sans-serif).
			font_mono: Monospace font stack (default Space Grotesk, then
				generic monospace).
		"""
		super().__init__(
			primary_hue=primary_hue,
			secondary_hue=secondary_hue,
			neutral_hue=neutral_hue,
			spacing_size=spacing_size,
			radius_size=radius_size,
			font=font,
			font_mono=font_mono,
		)
		super().set(
			# Primary buttons: left-to-right gradient across the two accent hues,
			# lighter on hover and darker in dark mode.
			button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
			button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
			button_primary_text_color="white",
			button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
			block_shadow="*shadow_drop_lg",
			button_shadow="*shadow_drop_lg",
			# "zinc" is not a CSS colour keyword, so the old value made this
			# variable invalid and browsers could not apply it. Use a concrete
			# shade from Gradio's zinc palette instead (another cNN shade can
			# be substituted to taste).
			input_background_fill=colors.zinc.c100,
			input_border_color="*secondary_300",
			input_shadow="*shadow_drop",
			input_shadow_focus="*shadow_drop_lg",
		)

# Theme instance that Gradio.Blocks is constructed with below.
custom_theme = PurpleTheme()

# Prompt template. The user's question is substituted for `{}` via str.format.
# NOTE(review): this is an Alpaca-style "### Instruction / ### Response"
# layout, whereas Mistral-7B-Instruct is usually prompted as
# "[INST] ... [/INST]". Confirm which format answers better with this model.
ins = "### Instruction:\n{}\n### Response:\n"

def run_falcon(input):
	"""Stream an answer to *input*, yielding the cumulative text after each token.

	Despite the name, this drives whichever model the module-level ``model``
	holds (loaded from a Mistral instruct GGUF file), not Falcon. The
	question is wrapped in the ``ins`` prompt template.

	Generation settings:

	- at most 768 tokens;
	- repeat penalty of 1.3;
	- ``repeat_last_n`` of 64.

	Args:
		input: The user's question from the Gradio input textbox. The name
			shadows the ``input`` builtin but is kept for compatibility.

	Yields:
		The whole response generated so far, so the Gradio output box
		re-renders the growing answer. For an empty or whitespace-only
		question a single ``""`` is yielded and no generation is run.
	"""
	if not (input or "").strip():
		# Nothing to answer: clear the output instead of spending up to 768
		# tokens of local inference on a blank prompt.
		yield ""
		return

	result = ""
	try:
		for token in model.generate(
			ins.format(input),
			max_tokens=768,
			streaming=True,
			repeat_penalty=1.3,
			repeat_last_n=64,
		):
			# Echo the stream inline on the console rather than one token per line.
			print(token, end="", flush=True)
			result += token
			yield result
	finally:
		# End the echoed line even if the client stops the stream early.
		print()

with Gradio.Blocks(
	theme=custom_theme,
	analytics_enabled=False,
	# Hides elements carrying the `generating` class (presumably Gradio's
	# in-progress styling) while the answer streams in.
	css=".generating {visibility: hidden}",
) as demo:
	with Gradio.Column():
		Gradio.Markdown(
			"""
			## CogniForge
			Uses mistral (7b_0)

			Type in the box below and click the button to generate answers to your most pressing questions!
			"""
		)

		# Question and answer each sit in their own Box container.
		with Gradio.Box():
			instruction = Gradio.Textbox(
				label="Input",
				info="What things do you want to ask GPT4ALL?",
				placeholder="What does the Philippine flag represent?",
			)
		with Gradio.Box():
			output = Gradio.Textbox(
				label="Output",
				info="GPT4ALL's thoughts",
				value="",
			)

	# The button sits outside the Column, directly under the Blocks root.
	submit = Gradio.Button("Generate", variant="primary")
	# run_falcon is a generator, so partial answers stream into `output`.
	submit.click(run_falcon, inputs=[instruction], outputs=[output])

# One queue worker: every request shares the single module-level model.
demo.queue(concurrency_count=1)  # type: ignore
demo.launch(debug=True)  # type: ignore