mtyrrell committed
Commit 23162a1 · Parent(s): 8170b18

refactored generator

Files changed (4)
  1. README.md +37 -30
  2. app/generator.py +224 -0
  3. app/main.py +9 -8
  4. app/utils.py +3 -135
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Chatfed Generation Service
  emoji: 🤖
  colorFrom: blue
  colorTo: purple
@@ -8,46 +8,53 @@ pinned: false
  license: mit
  ---

- # Generation Module
-
- This is an LLM-based generation service designed to be deployed as a modular component of a broader RAG system. The service runs on a docker container and exposes a gradio UI on port 7860 as well as an MCP endpoint.
-
- ## Configuration
-
- 1. The module requires an API key (set as an environment variable) for an inference provider to run. Multiple inference providers are supported. Make sure to set the appropriate environment variables:
-    - OpenAI: `OPENAI_API_KEY`
-    - Anthropic: `ANTHROPIC_API_KEY`
-    - Cohere: `COHERE_API_KEY`
-    - HuggingFace: `HF_TOKEN`
-
- 2. Inference provider and model settings are accessible via params.cfg

  ## MCP Endpoint

- ## Available Tools
-
- ### `rag_generate`
-
- Generate an answer to a query using provided context through RAG. This function takes a user query and relevant context, then uses a language model to generate a comprehensive answer based on the provided information.
-
- **Input Schema:**
-
- | Parameter | Type | Description |
- |-----------|------|-------------|
- | `query` | string | The user's question or query |
- | `context` | string | The relevant context/documents to use for answering |
-
- **Returns:** The generated answer based on the query and context
-
- **Example Usage:**
-
- ```json
- {
-     "query": "What are the benefits of renewable energy?",
-     "context": "Documents and information about renewable energy sources..."
- }
  ```

- ---
-
- *This tool uses an LLM to generate answers using the most relevant information from the context, along with the input query.*
  ---
+ title: ChatFed Generator
  emoji: 🤖
  colorFrom: blue
  colorTo: purple

  license: mit
  ---

+ # ChatFed Generator - MCP Server

+ A language model-based generation service designed for ChatFed RAG (Retrieval-Augmented Generation) pipelines. This module serves as an **MCP (Model Context Protocol) server** that generates contextual responses using configurable LLM providers, with support for processing retrieval results.

  ## MCP Endpoint

+ The main MCP function is `generate`, which provides context-aware text generation using the configured LLM provider once API credentials are in place.

+ **Parameters**:
+ - `query` (str, required): The question or query to be answered
+ - `context` (str | list, required): Context for answering - can be plain text or a list of retrieval result dictionaries

+ **Returns**: String containing the generated answer based on the provided context and query.

+ **Example usage**:
+ ```python
+ from gradio_client import Client

+ client = Client("ENTER CONTAINER URL / SPACE ID")
+ result = client.predict(
+     query="What are the key findings?",
+     context="Your relevant documents or context here...",
+     api_name="/generate"
+ )
+ print(result)
+ ```

+ ## Configuration

+ ### LLM Provider Configuration
+ 1. Set your preferred inference provider in `params.cfg`
+ 2. Configure the model and generation parameters
+ 3. Set the required API key environment variable
+ 4. [Optional] Adjust the temperature and max_tokens settings
+ 5. Run the app (see the note after the code block for passing the API key):

+ ```bash
+ docker build -t chatfed-generator .
+ docker run -p 7860:7860 chatfed-generator
  ```
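Note that the `docker run` command above does not pass an API key into the container, so model initialization would fail at startup (`get_auth` raises `RuntimeError` when the key is missing). A minimal sketch, assuming OpenAI is the provider configured in `params.cfg`:

```bash
# Forward the provider API key from the host environment (example assumes OpenAI)
docker run -p 7860:7860 -e OPENAI_API_KEY="$OPENAI_API_KEY" chatfed-generator
```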

+ ## Required Environment Variables
+
+ Set the API key environment variable for your chosen provider:
+ - OpenAI: `OPENAI_API_KEY`
+ - Anthropic: `ANTHROPIC_API_KEY`
+ - Cohere: `COHERE_API_KEY`
+ - HuggingFace: `HF_TOKEN`
+
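For reference, the `[generator]` section of `params.cfg` that both the steps above and `app/generator.py` read might look like the following sketch. The four keys match the `config.get("generator", ...)` calls in the new module; the values shown are purely illustrative:

```ini
[generator]
PROVIDER = openai
MODEL = gpt-4o-mini
MAX_TOKENS = 512
TEMPERATURE = 0.2
```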
app/generator.py ADDED
@@ -0,0 +1,224 @@
+ import logging
+ import asyncio
+ import json
+ import ast
+ from typing import List, Dict, Any, Union
+ from dotenv import load_dotenv
+
+ # LangChain imports
+ from langchain_openai import ChatOpenAI
+ from langchain_anthropic import ChatAnthropic
+ from langchain_cohere import ChatCohere
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+ from langchain_core.messages import SystemMessage, HumanMessage
+
+ # Local imports
+ from .utils import getconfig, get_auth
+
+ # ---------------------------------------------------------------------
+ # Model / client initialization (non-exhaustive list of providers)
+ # ---------------------------------------------------------------------
+ config = getconfig("params.cfg")
+
+ PROVIDER = config.get("generator", "PROVIDER")
+ MODEL = config.get("generator", "MODEL")
+ MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
+ TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
+
+ # Set up authentication for the selected provider
+ auth_config = get_auth(PROVIDER)
+
+ def get_chat_model():
+     """Initialize the appropriate LangChain chat model based on provider"""
+     common_params = {
+         "temperature": TEMPERATURE,
+         "max_tokens": MAX_TOKENS,
+     }
+
+     if PROVIDER == "openai":
+         return ChatOpenAI(
+             model=MODEL,
+             openai_api_key=auth_config["api_key"],
+             **common_params
+         )
+     elif PROVIDER == "anthropic":
+         return ChatAnthropic(
+             model=MODEL,
+             anthropic_api_key=auth_config["api_key"],
+             **common_params
+         )
+     elif PROVIDER == "cohere":
+         return ChatCohere(
+             model=MODEL,
+             cohere_api_key=auth_config["api_key"],
+             **common_params
+         )
+     elif PROVIDER == "huggingface":
+         # Initialize HuggingFaceEndpoint with explicit parameters
+         llm = HuggingFaceEndpoint(
+             repo_id=MODEL,
+             huggingfacehub_api_token=auth_config["api_key"],
+             task="text-generation",
+             temperature=TEMPERATURE,
+             max_new_tokens=MAX_TOKENS
+         )
+         return ChatHuggingFace(llm=llm)
+     else:
+         raise ValueError(f"Unsupported provider: {PROVIDER}")
+
+ # Initialize provider-agnostic chat model
+ chat_model = get_chat_model()
+
+ # ---------------------------------------------------------------------
+ # Context processing - may need further refinement (e.g. to handle other data sources)
+ # ---------------------------------------------------------------------
+ def extract_relevant_fields(retrieval_results: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
+     """
+     Extract only the relevant fields from retrieval results.
+
+     Args:
+         retrieval_results: List of JSON objects from the retriever,
+             or the string representation of such a list
+
+     Returns:
+         List of processed objects with only the relevant fields
+     """
+     # The MCP layer may deliver the list as its string representation;
+     # parse only in that case (ast.literal_eval fails on an actual list)
+     if isinstance(retrieval_results, str):
+         retrieval_results = ast.literal_eval(retrieval_results)
+
+     processed_results = []
+
+     for result in retrieval_results:
+         # Extract the answer content
+         answer = result.get('answer', '')
+
+         # Extract document identification from metadata
+         metadata = result.get('answer_metadata', {})
+         doc_info = {
+             'answer': answer,
+             'filename': metadata.get('filename', 'Unknown'),
+             'page': metadata.get('page', 'Unknown'),
+             'year': metadata.get('year', 'Unknown'),
+             'source': metadata.get('source', 'Unknown'),
+             'document_id': metadata.get('_id', 'Unknown')
+         }
+
+         processed_results.append(doc_info)
+
+     return processed_results
+
+ def format_context_from_results(processed_results: List[Dict[str, Any]]) -> str:
+     """
+     Format processed retrieval results into a context string for the LLM.
+
+     Args:
+         processed_results: List of processed objects with relevant fields
+
+     Returns:
+         Formatted context string
+     """
+     if not processed_results:
+         return ""
+
+     context_parts = []
+
+     for i, result in enumerate(processed_results, 1):
+         doc_reference = f"[Document {i}: {result['filename']}"
+         if result['page'] != 'Unknown':
+             doc_reference += f", Page {result['page']}"
+         if result['year'] != 'Unknown':
+             doc_reference += f", Year {result['year']}"
+         doc_reference += "]"
+
+         context_part = f"{doc_reference}\n{result['answer']}\n"
+         context_parts.append(context_part)
+
+     return "\n".join(context_parts)
+
+ # ---------------------------------------------------------------------
+ # Core generation function for both Gradio UI and MCP
+ # ---------------------------------------------------------------------
+ async def _call_llm(messages: list) -> str:
+     """
+     Provider-agnostic LLM call using LangChain.
+
+     Args:
+         messages: List of LangChain message objects
+
+     Returns:
+         Generated response content as string
+     """
+     try:
+         # Use async invoke for better performance
+         response = await chat_model.ainvoke(messages)
+         return response.content.strip()
+     except Exception as e:
+         logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
+         raise
+
+ def build_messages(question: str, context: str) -> list:
+     """
+     Build messages in LangChain format.
+
+     Args:
+         question: The user's question
+         context: The relevant context for answering
+
+     Returns:
+         List of LangChain message objects
+     """
+     system_content = (
+         "You are an expert assistant. Answer the USER question using only the "
+         "CONTEXT provided. If the context is insufficient say 'I don't know.'"
+     )
+
+     user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
+
+     return [
+         SystemMessage(content=system_content),
+         HumanMessage(content=user_content)
+     ]
+
+
+ async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
+     """
+     Generate an answer to a query using provided context through RAG.
+
+     This function takes a user query and relevant context, then uses a language model
+     to generate a comprehensive answer based on the provided information.
+
+     Args:
+         query (str): User query
+         context (str | list): Plain-text context, or a list of retrieval
+             result objects (dictionaries) from the retriever
+
+     Returns:
+         str: The generated answer based on the query and context
+     """
+     if not query.strip():
+         return "Error: Query cannot be empty"
+
+     # Handle both string context (for Gradio UI) and list context (from retriever)
+     if isinstance(context, list):
+         if not context:
+             return "Error: No retrieval results provided"
+
+         # Process the retrieval results
+         processed_results = extract_relevant_fields(context)
+         formatted_context = format_context_from_results(processed_results)
+
+         if not formatted_context.strip():
+             return "Error: No valid content found in retrieval results"
+
+     elif isinstance(context, str):
+         if not context.strip():
+             return "Error: Context cannot be empty"
+         formatted_context = context
+
+     else:
+         return "Error: Context must be either a string or a list of retrieval results"
+
+     try:
+         messages = build_messages(query, formatted_context)
+         answer = await _call_llm(messages)
+         return answer
+     except Exception as e:
+         logging.exception("Generation failed")
+         return f"Error: {str(e)}"
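Since the new `generate` accepts either a plain string or retriever output, a hypothetical smoke test helps pin down the expected shape. The field names (`answer`, `answer_metadata`, `filename`, `page`, `year`, `source`, `_id`) come straight from `extract_relevant_fields` above; the sample values are invented, and the script assumes `params.cfg` and the provider API key are already in place (the module initializes the chat model at import time):

```python
import asyncio

from app.generator import generate

# Hypothetical retrieval result in the shape extract_relevant_fields expects
sample_results = [
    {
        "answer": "Installed solar capacity grew 24% year over year.",
        "answer_metadata": {
            "filename": "energy_report.pdf",
            "page": 12,
            "year": 2023,
            "source": "retriever",
            "_id": "doc-001",
        },
    }
]

# generate() is async, so drive it with asyncio.run for a one-off call
print(asyncio.run(generate("What are the key findings?", sample_results)))
```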
app/main.py CHANGED
@@ -1,23 +1,23 @@
  import gradio as gr
- from .utils import rag_generate

  # ---------------------------------------------------------------------
  # Gradio Interface with MCP support
  # ---------------------------------------------------------------------
  ui = gr.Interface(
-     fn=rag_generate,
      inputs=[
          gr.Textbox(
              label="Query",
              lines=2,
-             placeholder="What would you like to know?",
-             info="Enter your question here"
          ),
          gr.Textbox(
              label="Context",
              lines=8,
-             placeholder="Paste relevant documents or context here...",
-             info="Provide the context/documents to use for answering"
          ),
      ],
      outputs=gr.Textbox(
@@ -25,8 +25,9 @@ ui = gr.Interface(
          lines=6,
          show_copy_button=True
      ),
-     title="RAG Generation Service",
-     description="Ask questions based on provided context. Intended for use in RAG pipelines (i.e. context supplied by semantic retriever service) as an MCP server.",
  )

  # Launch with MCP server enabled

  import gradio as gr
+ from .generator import generate

  # ---------------------------------------------------------------------
  # Gradio Interface with MCP support
  # ---------------------------------------------------------------------
  ui = gr.Interface(
+     fn=generate,
      inputs=[
          gr.Textbox(
              label="Query",
              lines=2,
+             placeholder="Enter query here",
+             info="The question to answer using the provided context"
          ),
          gr.Textbox(
              label="Context",
              lines=8,
+             placeholder="Paste relevant context here",
+             info="Provide the context/documents to use for answering. The API expects a list of dictionaries, but the UI should accept anything"
          ),
      ],
      outputs=gr.Textbox(

          lines=6,
          show_copy_button=True
      ),
+     title="ChatFed Generation Module",
+     description="Ask questions based on provided context. Intended for use in RAG pipelines as an MCP server with other ChatFed modules (i.e. context supplied by the semantic retriever service).",
+     api_name="generate"
  )

  # Launch with MCP server enabled
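The `ui.launch(...)` call itself falls outside this diff's hunks. For orientation, a hedged sketch of what the unchanged tail of `app/main.py` presumably looks like: `mcp_server=True` is the flag Gradio uses to expose an interface as an MCP endpoint, and port 7860 matches the README; the exact arguments in the repo are unverified.

```python
ui.launch(
    server_name="0.0.0.0",  # reachable from outside the container
    server_port=7860,       # port published in the README's docker run
    mcp_server=True         # serve the interface as an MCP endpoint
)
```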
app/utils.py CHANGED
@@ -1,14 +1,9 @@
- import os, asyncio, logging
  import configparser
  import logging
  from dotenv import load_dotenv

- # LangChain imports
- from langchain_openai import ChatOpenAI
- from langchain_anthropic import ChatAnthropic
- from langchain_cohere import ChatCohere
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
- from langchain_core.messages import SystemMessage, HumanMessage

  # Local .env file
  load_dotenv()
@@ -30,7 +25,7 @@ def getconfig(configfile_path: str):
  # ---------------------------------------------------------------------
  # Provider-agnostic authentication and configuration
  # ---------------------------------------------------------------------
- def get_auth_config(provider: str) -> dict:
      """Get authentication configuration for different providers"""
      auth_configs = {
          "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
@@ -49,130 +44,3 @@ def get_auth_config(provider: str) -> dict:
          raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")

      return auth_config
-
- # ---------------------------------------------------------------------
- # Model / client initialization
- # ---------------------------------------------------------------------
- config = getconfig("params.cfg")
-
- PROVIDER = config.get("generator", "PROVIDER")
- MODEL = config.get("generator", "MODEL")
- MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
- TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
-
- # Set up authentication for the selected provider
- auth_config = get_auth_config(PROVIDER)
-
- def get_chat_model():
-     """Initialize the appropriate LangChain chat model based on provider"""
-     common_params = {
-         "temperature": TEMPERATURE,
-         "max_tokens": MAX_TOKENS,
-     }
-
-     if PROVIDER == "openai":
-         return ChatOpenAI(
-             model=MODEL,
-             openai_api_key=auth_config["api_key"],
-             **common_params
-         )
-     elif PROVIDER == "anthropic":
-         return ChatAnthropic(
-             model=MODEL,
-             anthropic_api_key=auth_config["api_key"],
-             **common_params
-         )
-     elif PROVIDER == "cohere":
-         return ChatCohere(
-             model=MODEL,
-             cohere_api_key=auth_config["api_key"],
-             **common_params
-         )
-     elif PROVIDER == "huggingface":
-         # Initialize HuggingFaceEndpoint with explicit parameters
-         llm = HuggingFaceEndpoint(
-             repo_id=MODEL,
-             huggingfacehub_api_token=auth_config["api_key"],
-             task="text-generation",
-             temperature=TEMPERATURE,
-             max_new_tokens=MAX_TOKENS
-         )
-         return ChatHuggingFace(llm=llm)
-     else:
-         raise ValueError(f"Unsupported provider: {PROVIDER}")
-
- # Initialize provider-agnostic chat model
- chat_model = get_chat_model()
-
- # ---------------------------------------------------------------------
- # Core generation function for both Gradio UI and MCP
- # ---------------------------------------------------------------------
- async def _call_llm(messages: list) -> str:
-     """
-     Provider-agnostic LLM call using LangChain.
-
-     Args:
-         messages: List of LangChain message objects
-
-     Returns:
-         Generated response content as string
-     """
-     try:
-         # Use async invoke for better performance
-         response = await chat_model.ainvoke(messages)
-         return response.content.strip()
-     except Exception as e:
-         logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
-         raise
-
- def build_messages(question: str, context: str) -> list:
-     """
-     Build messages in LangChain format.
-
-     Args:
-         question: The user's question
-         context: The relevant context for answering
-
-     Returns:
-         List of LangChain message objects
-     """
-     system_content = (
-         "You are an expert assistant. Answer the USER question using only the "
-         "CONTEXT provided. If the context is insufficient say 'I don't know.'"
-     )
-
-     user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
-
-     return [
-         SystemMessage(content=system_content),
-         HumanMessage(content=user_content)
-     ]
-
-
- async def rag_generate(query: str, context: str) -> str:
-     """
-     Generate an answer to a query using provided context through RAG.
-
-     This function takes a user query and relevant context, then uses a language model
-     to generate a comprehensive answer based on the provided information.
-
-     Args:
-         query (str): The user's question or query
-         context (str): The relevant context/documents to use for answering
-
-     Returns:
-         str: The generated answer based on the query and context
-     """
-     if not query.strip():
-         return "Error: Query cannot be empty"
-
-     if not context.strip():
-         return "Error: Context cannot be empty"
-
-     try:
-         messages = build_messages(query, context)
-         answer = await _call_llm(messages)
-         return answer
-     except Exception as e:
-         logging.exception("Generation failed")
-         return f"Error: {str(e)}"
 
+ import os
  import configparser
  import logging
  from dotenv import load_dotenv

+

  # Local .env file
  load_dotenv()

  # ---------------------------------------------------------------------
  # Provider-agnostic authentication and configuration
  # ---------------------------------------------------------------------
+ def get_auth(provider: str) -> dict:
      """Get authentication configuration for different providers"""
      auth_configs = {
          "openai": {"api_key": os.getenv("OPENAI_API_KEY")},

          raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")

      return auth_config
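A small hypothetical smoke test for the renamed helper; it assumes only the environment variable names listed in the README and that the module is imported from the repo root (`app/utils.py` has no import-time config reads after this refactor, so the import itself is side-effect free apart from `load_dotenv`):

```python
# Hypothetical smoke test for the renamed get_auth helper
import os

from app.utils import get_auth

os.environ.setdefault("OPENAI_API_KEY", "placeholder-key")  # illustrative only
print(get_auth("openai"))  # raises RuntimeError if no key is configured
```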