m7n's picture
Update app.py
955747f verified
raw
history blame
4.11 kB
import os
# NOTE(review): gradio is uninstalled and reinstalled from PyPI on every start-up,
# before `import gradio` below (keep that ordering, otherwise the freshly installed
# version would not be the one imported). Presumably this is done so newer APIs such
# as `gr.set_static_paths` (used below) are available — confirm. It slows start-up,
# makes the installed version depend on whatever PyPI serves at boot time, and the
# exit codes of both os.system calls are ignored; pinning the version in the Space
# configuration / requirements would be more reproducible.
os.system("pip uninstall -y gradio")
os.system("pip install --upgrade gradio")
from pathlib import Path
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
import gradio as gr
from datetime import datetime
import sys
# Directory that receives the generated plot pages (and the downloaded basemap
# file). Create it up front so it exists before it is mounted below.
static_dir = Path('./static')
static_dir.mkdir(exist_ok=True, parents=True)

# Let Gradio serve files out of the same directory as well.
gr.set_static_paths(paths=["static/"])

# Host ASGI application; the Gradio UI is mounted onto it at the end of the module,
# and everything in static_dir is exposed under the /static URL prefix.
app = FastAPI()
app.mount(path="/static", app=StaticFiles(directory=static_dir), name="static")
# Gradio stuff
import datamapplot
import numpy as np
import requests
import io
import pandas as pd
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
from itertools import chain
from compress_pickle import load, dump
def _invert_abstract(inv_index):
    """Rebuild abstract text from an OpenAlex-style inverted index.

    ``inv_index`` maps each word to the list of positions it occupies; the
    words are emitted in ascending position order, separated by single spaces.
    Anything that is not a dict yields the placeholder ``' '``. That covers
    ``None`` for works without an abstract, and the NaN pandas fills in when
    some records lack the key.
    """
    if not isinstance(inv_index, dict):
        return ' '
    placed = [(pos, word) for word, positions in inv_index.items() for pos in positions]
    # Stable sort on the position only, so equal positions keep dict order.
    placed.sort(key=lambda item: item[0])
    return " ".join(word for _, word in placed)


def query_records(search_term):
    """Fetch the OpenAlex works whose abstract matches ``search_term``.

    Pages through ``Works().search_filter(abstract=...)`` 200 records per
    request. pyalex's ``paginate`` may cap the total; see its ``n_max``
    argument.

    Returns:
        A DataFrame with one row per work (columns taken from the API records)
        plus an ``abstract`` column rebuilt from ``abstract_inverted_index``.
        When the search returns nothing, the DataFrame is empty but still has
        an ``abstract`` column instead of raising ``KeyError``.
    """
    query = Works().search_filter(abstract=search_term)
    # paginate() yields one list of records per page; flatten them lazily.
    records = list(chain.from_iterable(query.paginate(per_page=200)))
    records_df = pd.DataFrame(records)
    # An empty result has no columns at all, so fall back to "no index" per row.
    inverted = records_df.get('abstract_inverted_index', [None] * len(records_df))
    records_df['abstract'] = [_invert_abstract(t) for t in inverted]
    return records_df
def predict(text_input, progress=gr.Progress()):
    """Run an OpenAlex query and render the basemap as a standalone HTML page.

    Args:
        text_input: search term passed to ``query_records`` (matched against abstracts).
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(link, iframe)``: an ``<a>`` tag and an ``<iframe>`` tag, both pointing
        at the saved page under the ``/static`` mount.

    NOTE(review): the OpenAlex records fetched here are only printed; the plot
    is built from the module-level ``basedata_df`` alone. Presumably the query
    results are meant to be overlaid on the basemap — confirm.
    """
    import uuid
    from datetime import timezone

    # get data.
    records_df = query_records(text_input)
    print(records_df)

    # Aware UTC timestamp plus a random suffix. The name stays unique even when
    # two requests arrive in the same second, so one user's page is never
    # overwritten by another's. strftime('%s') was a non-portable glibc extension
    # that also reads the naive utcnow() value as local time.
    stamp = int(datetime.now(timezone.utc).timestamp())
    file_name = f"{stamp}_{uuid.uuid4().hex[:8]}.html"
    file_path = static_dir / file_name
    print(file_path)

    progress(0.7, desc="Loading hover data...")
    # Zip the needed columns directly; DataFrame.iterrows() builds a Series per
    # row, which is slow over the whole basemap.
    hover_text = [
        f"{ix}, {publication}, {title}"
        for ix, publication, title in zip(
            basedata_df.index,
            basedata_df['parsed_publication'],
            basedata_df['title'],
        )
    ]
    plot = datamapplot.create_interactive_plot(
        basedata_df[['x', 'y']].values,
        np.array(basedata_df['cluster_1_labels']),
        hover_text=hover_text,
        font_family="Roboto Condensed",
    )
    progress(0.9, desc="Saving plot...")
    plot.save(file_path)
    progress(1.0, desc="Done!")

    iframe = f"""<iframe src="/static/{file_name}" width="100%" height="500px"></iframe>"""
    link = f'<a href="/static/{file_name}" target="_blank">{file_name}</a>'
    return link, iframe
# Gradio UI: a search box and a button in the left column, and the rendered plot
# preview in the right column. `block` is mounted onto the FastAPI app further down.
with gr.Blocks() as block:
    gr.Markdown("""
## Gradio + FastAPI + Static Server
This is a demo of how to use Gradio with FastAPI and a static server.
The Gradio app generates dynamic HTML files and stores them in a static directory. FastAPI serves the static files.
""")
    with gr.Row():
        with gr.Column():
            # predict() forwards this value to query_records() as an abstract
            # search term, so label it as such (it was previously labelled "Name").
            text_input = gr.Textbox(label="Search term")
            # Receives the link to the generated page (first value returned by predict).
            markdown = gr.Markdown(label="Output Box")
            new_btn = gr.Button("New")
        with gr.Column():
            # Receives the <iframe> embedding the generated page (second value returned by predict).
            html = gr.HTML(label="HTML preview", show_label=True)
    new_btn.click(fn=predict, inputs=[text_input], outputs=[markdown, html])
def setup_basemap_data(timeout=60):
    """Download the background-map sample and load it.

    The compressed pickle is streamed to ``static/<file name>`` and then read
    back with ``compress_pickle.load``, which picks the decompressor from the
    ``.bz`` extension.

    Args:
        timeout: seconds handed to ``requests.get`` for connecting and for each
            read. Without a timeout, a stalled server would block start-up
            indefinitely.

    Returns:
        The unpickled object; ``predict`` uses it as a DataFrame.

    Raises:
        requests.HTTPError: if the server answers with an error status. Without
            this check, the error page would be saved as the ``.bz`` file and
            fail later inside ``load``.
    """
    # get data.
    print("getting basemap data...")
    # Same directory that the module mounts under /static.
    out_dir = Path("static")
    out_dir.mkdir(exist_ok=True)
    bz_file_name = "100k_filtered_OA_sample_cluster_and_positions.bz"
    bz_file_path = out_dir / bz_file_name

    response = requests.get(
        "https://www.maxnoichl.eu/full/oa_project_on_scimap_background_data/100k_filtered_OA_sample_cluster_and_positions.bz",
        timeout=timeout,
        stream=True,
    )
    response.raise_for_status()
    # Stream to disk in 1 MiB chunks instead of buffering the whole file in
    # memory; the `with` releases the connection afterwards.
    with response, open(bz_file_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)

    # NOTE(review): this unpickles a file fetched over the network. Unpickling
    # can execute arbitrary code, so it fully trusts the remote host. The file
    # also lands in the publicly served static/ directory.
    basedata_df = load(bz_file_path)
    return basedata_df
# Fetch the background map once, when the module is imported; predict() reads
# this module-level DataFrame on every click.
basedata_df = setup_basemap_data()

# Attach the Gradio UI to the FastAPI app at the site root.
app = gr.mount_gradio_app(app=app, blocks=block, path="/")

if __name__ == "__main__":
    # Start the server directly with:  python app.py
    # or equivalently:                 uvicorn "app:app" --host "0.0.0.0" --port 7860 --reload
    uvicorn.run(app, host="0.0.0.0", port=7860)