kevinconka committed on
Commit
273f5f9
·
1 Parent(s): a5d5a84

Save query + result in dataset

Browse files
Files changed (4) hide show
  1. .gitignore +163 -1
  2. app.py +35 -39
  3. chatbot.py +30 -0
  4. flagging.py +79 -0
.gitignore CHANGED
@@ -1 +1,163 @@
1
- *.html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.html
2
+ flagged/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
app.py CHANGED
@@ -1,53 +1,41 @@
1
  import urllib.request
 
 
2
 
3
- from langchain.chains import RetrievalQA
4
- from langchain_community.document_loaders import UnstructuredHTMLLoader
5
- from langchain_openai import OpenAIEmbeddings
6
- from langchain_openai import ChatOpenAI
7
- from langchain.text_splitter import CharacterTextSplitter
8
- from langchain_community.vectorstores import Chroma
9
 
10
- import gradio as gr
11
 
12
  # get the html data and save it to a file
 
 
 
 
 
 
13
  url = "https://sea.ai/faq"
14
- html = urllib.request.urlopen(url).read()
15
- with open("FAQ_SEA.AI.html", "wb") as f:
16
- f.write(html)
17
-
18
- # load documents
19
- loader = UnstructuredHTMLLoader("FAQ_SEA.AI.html")
20
- documents = loader.load()
21
- # split the documents into chunks
22
- text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
23
- texts = text_splitter.split_documents(documents)
24
- # select which embeddings we want to use
25
- embeddings = OpenAIEmbeddings()
26
-
27
- # create the vectorestore to use as the index
28
- db = Chroma.from_documents(texts, embeddings)
29
- # expose this index in a retriever interface
30
- retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})
31
- # create a chain to answer questions
32
- qa = RetrievalQA.from_chain_type(
33
- llm=ChatOpenAI(),
34
- chain_type="stuff",
35
- retriever=retriever,
36
- return_source_documents=True,
37
- verbose=True,
38
- )
39
 
40
 
41
  def answer_question(message, history, system):
42
- # unwind the history of last 2 messages
43
- history = " ".join(f"{user} {bot}" for user, bot in history[-2:])
44
  # concatenate the history, message and system
45
- query = " ".join([history, message, system])
46
  retrieval_qa = qa.invoke(query)
47
  result = retrieval_qa["result"]
48
  result = result.replace('"', "").strip() # clean up the result
49
  # query = retrieval_qa["query"]
50
  # source_documents = retrieval_qa["source_documents"]
 
 
 
51
  return result
52
 
53
 
@@ -56,8 +44,11 @@ description = """
56
  <p align="center">
57
  I have memorized the entire SEA.AI FAQ page. Ask me anything about it! 🧠
58
  <br>
59
- You can modify my response by using the <code>SYSTEM</code> input under
60
- <code>Additional Inputs</code>.
 
 
 
61
  </p>
62
  """
63
 
@@ -70,7 +61,7 @@ h1 {
70
 
71
  theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)
72
 
73
- demo = gr.ChatInterface(
74
  answer_question,
75
  title=title,
76
  description=description,
@@ -81,7 +72,12 @@ demo = gr.ChatInterface(
81
  ],
82
  css=css,
83
  theme=theme,
84
- )
 
 
 
 
 
85
 
86
  if __name__ == "__main__":
87
  demo.launch()
 
1
  import urllib.request
2
+ import gradio as gr
3
+ from huggingface_hub import get_token
4
 
5
+ from chatbot import get_retrieval_qa
6
+ from flagging import myHuggingFaceDatasetSaver
 
 
 
 
7
 
 
8
 
9
  # get the html data and save it to a file
10
def download_html(_url: str, _filename: str) -> None:
    """Fetch ``_url`` and write the raw response body to ``_filename``.

    Args:
        _url: Address handed to ``urllib.request.urlopen``.
        _filename: Destination path. It is opened in binary write mode, so an
            existing file is overwritten.
    """
    # Close the response deterministically instead of leaving the underlying
    # connection/file handle open until garbage collection.
    with urllib.request.urlopen(_url) as response:
        html = response.read()
    with open(_filename, "wb") as f:
        f.write(html)
14
+
15
+
16
  url = "https://sea.ai/faq"
17
+ filename = "FAQ_SEA.AI.html"
18
+ download_html(url, filename)
19
+
20
+ # load the retrieval QA model
21
+ qa = get_retrieval_qa(filename)
22
+
23
+ # dataset callback
24
+ dataset_name = "SEA-AI/seadog-chat-history"
25
+ hf_writer = myHuggingFaceDatasetSaver(get_token(), dataset_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
def answer_question(message, history, system):
    """Answer a chat message with the retrieval QA chain and log the exchange.

    Args:
        message: Latest user message.
        history: Previous chat turns. Accepted because the chat callback is
            called with it, but not used when building the query.
        system: Extra instruction text appended after the message.

    Returns:
        The chain's ``"result"`` text with double quotes removed and
        surrounding whitespace stripped.
    """
    import logging

    # Build the query from the message and the system text only (the history
    # is not included). Empty or None parts are skipped so the query never
    # carries a dangling separator and a missing system text cannot raise.
    query = " ".join(part for part in (message, system) if part)

    retrieval_qa = qa.invoke(query)
    # Clean up the result. The chain output is also expected to carry "query"
    # and "source_documents" entries, which are not used here.
    result = retrieval_qa["result"].replace('"', "").strip()

    # Save the query and result to the dataset. This is analytics only, so a
    # failure here is logged and must not discard an answer already computed.
    try:
        hf_writer.flag([query, result])
    except Exception:
        logging.getLogger(__name__).exception(
            "Could not save query/result to the dataset"
        )
    return result
40
 
41
 
 
44
  <p align="center">
45
  I have memorized the entire SEA.AI FAQ page. Ask me anything about it! 🧠
46
  <br>
47
+ I can't remember conversations yet, be patient with me.
48
+ <br>
49
+ DISCLAIMER: Your queries will be saved to
50
+ <a href='https://huggingface.co/datasets/SEA-AI/seadog-chat-history'>this dataset</a>
51
+ for analytics purposes.
52
  </p>
53
  """
54
 
 
61
 
62
  theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)
63
 
64
+ with gr.ChatInterface(
65
  answer_question,
66
  title=title,
67
  description=description,
 
72
  ],
73
  css=css,
74
  theme=theme,
75
+ ) as demo:
76
+ # on page load, download the html and save it to a file
77
+ demo.load(lambda: download_html(url, filename))
78
+ # This needs to be called prior to the first call to callback.flag()
79
+ hf_writer.setup([demo.textbox, demo.chatbot], "flagged")
80
+
81
 
82
  if __name__ == "__main__":
83
  demo.launch()
chatbot.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains import RetrievalQA
2
+ from langchain_community.document_loaders import UnstructuredHTMLLoader
3
+ from langchain_openai import OpenAIEmbeddings
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from langchain_community.vectorstores import Chroma
7
+
8
+
9
def get_retrieval_qa(filename, *, chunk_size=1000, chunk_overlap=0, k=2):
    """Build a RetrievalQA chain over the HTML document stored in ``filename``.

    Args:
        filename: Path of the HTML file handed to ``UnstructuredHTMLLoader``.
        chunk_size: ``chunk_size`` passed to ``CharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` passed to ``CharacterTextSplitter``.
        k: Number of chunks the similarity retriever is asked for per query.

    Returns:
        The chain created by ``RetrievalQA.from_chain_type`` with the "stuff"
        chain type, configured to return source documents and to be verbose.
    """
    # load the document and split it into chunks
    documents = UnstructuredHTMLLoader(filename).load()
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    texts = text_splitter.split_documents(documents)

    # create the vectorstore used as the index, embedding with OpenAI
    db = Chroma.from_documents(texts, OpenAIEmbeddings())

    # expose the index through a retriever interface
    retriever = db.as_retriever(
        search_type="similarity",
        search_kwargs={"k": k},
    )

    # create a chain to answer questions
    return RetrievalQA.from_chain_type(
        llm=ChatOpenAI(),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        verbose=True,
    )
flagging.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from pathlib import Path
3
+ from typing import Any
4
+ import gradio as gr
5
+ from gradio.flagging import HuggingFaceDatasetSaver, client_utils
6
+ import huggingface_hub
7
+
8
+
9
class myHuggingFaceDatasetSaver(HuggingFaceDatasetSaver):
    """
    HuggingFaceDatasetSaver that records ``gr.Chatbot`` samples as-is.

    Only ``_deserialize_components`` is overridden: every component is still
    serialized through ``component.flag``, except ``gr.Chatbot`` components,
    whose sample is stored verbatim (a workaround for what appears to be a bug
    in Gradio's own implementation).
    """

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Images/audio are saved to disk as individual files.

        Args:
            data_dir: Directory under which one sub-directory per component
                label is created for files written by ``component.flag``.
            flag_data: One sample per entry of ``self.components``, same order.
            flag_option: Value stored in the row's ``flag`` column.
            username: Value stored in the row's ``username`` column.

        Returns:
            ``(features, row)``: an ordered mapping from column name to its
            feature description, and the list of column values for the sample.
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        # Generate the row corresponding to the flagged sample
        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            # Get deserialized object (will save sample to disk if applicable)
            label = component.label or ""
            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            if isinstance(component, gr.Chatbot):
                # Workaround: keep the chatbot sample untouched instead of
                # routing it through component.flag.
                deserialized = sample
            else:
                deserialized = component.flag(sample, save_dir)

            # Add deserialized object to row
            features[label] = {"dtype": "string", "_type": "Value"}
            # Store an existing file as a path relative to the dataset dir;
            # anything else (free text, None, a path outside the dataset dir)
            # is stored as plain text. The existence check is explicit rather
            # than an assert so it is not stripped when running with -O.
            try:
                candidate = Path(deserialized)
                relative = (
                    str(candidate.relative_to(self.dataset_dir))
                    if candidate.exists()
                    else None
                )
            except (TypeError, ValueError, OSError):
                relative = None
            if relative is None:
                deserialized = "" if deserialized is None else str(deserialized)
                row.append(deserialized)
            else:
                row.append(relative)

            # If component is eligible for a preview, add the URL of the file
            # Be mindful that images and audio can be None
            preview_type = next(
                (
                    _type
                    for _component, _type in file_preview_types.items()
                    if isinstance(component, _component)
                ),
                None,
            )
            if preview_type is not None:
                features[label + " file"] = {"_type": preview_type}
                if deserialized:
                    path_in_repo = str(
                        # returned filepath is absolute, we want it relative to compute URL
                        Path(deserialized).relative_to(self.dataset_dir)
                    ).replace("\\", "/")
                    row.append(
                        huggingface_hub.hf_hub_url(
                            repo_id=self.dataset_id,
                            filename=path_in_repo,
                            repo_type="dataset",
                        )
                    )
                else:
                    row.append("")
        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row