Spaces:
Sleeping
Sleeping
File size: 3,216 Bytes
0f64bae 7b47aa3 48d9af7 c21a510 48d9af7 19e42bb 0f64bae 7b47aa3 48d9af7 0f64bae 48d9af7 0f64bae c21a510 0f64bae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from typing import Optional, Type, Annotated
from pydantic import BaseModel, Field
from langchain.tools import BaseTool
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
import csv
import uuid
import os
class FlashcardInput(BaseModel):
    """Argument schema for the flashcard tool: the cards to write out."""

    # ``list[dict]`` rather than bare ``list``: the generated JSON schema then
    # advertises ``items: {"type": "object"}``, so a tool-calling model is told
    # each card is an object, and pydantic validates each item as a dict.
    flashcards: list[dict] = Field(description="A list of flashcards. Each flashcard should be a dictionary with 'question' and 'answer' keys.")
class FlashcardTool(BaseTool):
    """
    FlashcardTool class.

    A LangChain tool that writes question/answer flashcards to a two-column
    ``Front``/``Back`` CSV file intended for import into Anki.

    Attributes:
        name (str): The name of the tool.
        description (str): The description of the tool.
        args_schema (Type[BaseModel]): The schema for the input arguments of the tool.

    Methods:
        _run(flashcards: list, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
            Write the flashcards to a new, uniquely named CSV file under ``flashcards/``.
        _arun(flashcards: list, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str:
            Not supported; always raises ``NotImplementedError``.
    """

    # Type annotations are required for pydantic v2 (used by recent LangChain
    # releases) to accept these as overrides of the inherited fields; they are
    # harmless under pydantic v1.
    name: str = "create_flashcards"
    description: str = "Create flashcards in a .csv format suitable for import into Anki"
    args_schema: Type[BaseModel] = FlashcardInput

    def _run(
        self, flashcards: list, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Write ``flashcards`` to a new CSV file and return a confirmation.

        Args:
            flashcards: Dicts that each carry ``'question'`` and ``'answer'`` keys.
            run_manager: LangChain callback manager (not used here).

        Returns:
            A short confirmation message for the calling agent.

        Raises:
            KeyError: If a card lacks ``'question'`` or ``'answer'``. It is raised
                before any file or directory is created, so no partial CSV is
                left behind.
        """
        # Build every row up front so a malformed card fails fast instead of
        # leaving a half-written file on disk.
        rows = [{'Front': card['question'], 'Back': card['answer']} for card in flashcards]

        # A random UUID in the name keeps repeated runs from overwriting each other.
        filename = f"flashcards_{uuid.uuid4()}.csv"
        save_path = os.path.join('flashcards', filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # newline='' is what the csv module requires for correct line endings.
        # Explicit UTF-8 keeps non-ASCII card text from raising or being
        # mis-encoded on platforms whose default encoding is not UTF-8.
        with open(save_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=['Front', 'Back'])
            writer.writeheader()
            writer.writerows(rows)

        # \033[93m / \033[0m wrap the console message in yellow ANSI colour.
        print("\033[93m" + f"Flashcards successfully created and saved to {save_path}" + "\033[0m")
        return "csv file created successfully."

    async def _arun(
        self, flashcards: list, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Async execution is not supported by this tool.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("create_flashcards does not support async")
# Module-level FlashcardTool instance, built once at import time so other code
# can import it ready to use (presumably handed to an agent elsewhere — confirm).
create_flashcards_tool: FlashcardTool = FlashcardTool()
class RetrievalChainWrapper:
    """
    Thin adapter exposing a retrieval (RAG) chain through a plain method.

    Attributes:
        retrieval_chain: The wrapped chain. It is called through ``invoke`` with
            a ``{"question": ...}`` mapping and must return a mapping whose
            ``"response"`` entry has a ``content`` attribute.

    Methods:
        retrieve_information(query: str):
            Ask the wrapped chain a question and return the answer's ``content``.
    """

    def __init__(self, retrieval_chain):
        # Only store the chain; nothing is invoked at construction time.
        self.retrieval_chain = retrieval_chain

    def retrieve_information(
        self,
        query: Annotated[str, "query to ask the RAG tool"]
    ):
        """Use this tool to retrieve information about the provided notebook."""
        # The chain takes its input under "question"; the answer object sits
        # under "response" (presumably a chat message — verify against the chain).
        chain_input = {"question": query}
        chain_output = self.retrieval_chain.invoke(chain_input)
        answer_message = chain_output["response"]
        return answer_message.content
|