import asyncio
import json
import os
import logging
from typing import AsyncGenerator, Dict, Any, Optional, List
import base64

from google.cloud import speech
from google.cloud import texttospeech
from google.api_core import exceptions
import openai

from config import config

logger = logging.getLogger(__name__)


class SpeechToTextService:
    def __init__(self, language_code: str = "en-US"):
        # Async client so the async request generator in transcribe_streaming
        # can be consumed without blocking the event loop.
        self.client = speech.SpeechAsyncClient()
        self.language_code = language_code

        encoding_map = {
            "WEBM_OPUS": speech.RecognitionConfig.AudioEncoding.WEBM_OPUS,
            "LINEAR16": speech.RecognitionConfig.AudioEncoding.LINEAR16,
            "FLAC": speech.RecognitionConfig.AudioEncoding.FLAC,
            "MULAW": speech.RecognitionConfig.AudioEncoding.MULAW,
            "AMR": speech.RecognitionConfig.AudioEncoding.AMR,
            "AMR_WB": speech.RecognitionConfig.AudioEncoding.AMR_WB,
            "OGG_OPUS": speech.RecognitionConfig.AudioEncoding.OGG_OPUS,
            "MP3": speech.RecognitionConfig.AudioEncoding.MP3,
        }

        self.recognition_config = speech.RecognitionConfig(
            encoding=encoding_map.get(config.SPEECH_ENCODING, speech.RecognitionConfig.AudioEncoding.WEBM_OPUS),
            sample_rate_hertz=config.SPEECH_SAMPLE_RATE,
            language_code=self.language_code,
            enable_automatic_punctuation=True,
            use_enhanced=True,
            model="latest_long",
        )

        self.streaming_config = speech.StreamingRecognitionConfig(
            config=self.recognition_config,
            interim_results=True,
            single_utterance=False,
        )

    async def transcribe_streaming(self, audio_generator: AsyncGenerator[bytes, None]) -> AsyncGenerator[Dict[str, Any], None]:
        """Stream audio data to Google Cloud Speech-to-Text and yield transcription results."""
        try:
            async def request_generator():
                # The first request carries only the streaming config; audio chunks follow.
                yield speech.StreamingRecognizeRequest(streaming_config=self.streaming_config)

                async for chunk in audio_generator:
                    yield speech.StreamingRecognizeRequest(audio_content=chunk)

            responses = await self.client.streaming_recognize(requests=request_generator())

            async for response in responses:
                for result in response.results:
                    if not result.alternatives:
                        continue
                    transcript = result.alternatives[0].transcript
                    is_final = result.is_final

                    yield {
                        "type": "transcription",
                        "transcript": transcript,
                        "is_final": is_final,
                        "confidence": result.alternatives[0].confidence if is_final else 0.0,
                    }
        except exceptions.GoogleAPICallError as e:
            yield {
                "type": "error",
                "message": f"Speech recognition error: {str(e)}",
            }

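
# Illustrative usage sketch (not wired into the services below): one way to feed
# SpeechToTextService.transcribe_streaming from an asyncio.Queue of raw audio
# chunks, e.g. bytes received over a WebSocket. The queue and the None sentinel
# are assumptions made for this example only.
async def _example_transcribe_from_queue(stt: "SpeechToTextService", audio_queue: "asyncio.Queue") -> None:
    async def audio_generator() -> AsyncGenerator[bytes, None]:
        while True:
            chunk = await audio_queue.get()
            if chunk is None:  # sentinel: no more audio
                break
            yield chunk

    async for result in stt.transcribe_streaming(audio_generator()):
        if result.get("type") == "transcription" and result.get("is_final"):
            logger.info("Final transcript: %s", result["transcript"])
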

class TextToSpeechService:
    def __init__(self, language_code: str = "en-US"):
        self.client = texttospeech.TextToSpeechClient()
        self.language_code = language_code

        self.gender_map = {
            "FEMALE": texttospeech.SsmlVoiceGender.FEMALE,
            "MALE": texttospeech.SsmlVoiceGender.MALE,
            "NEUTRAL": texttospeech.SsmlVoiceGender.NEUTRAL,
            "male": texttospeech.SsmlVoiceGender.MALE,
            "female": texttospeech.SsmlVoiceGender.FEMALE,
        }

    def _get_voice_config(self, gender: str, character_name: Optional[str] = None) -> Dict[str, Any]:
        """Override this method in language-specific implementations"""
        tts_gender = self.gender_map.get(gender, texttospeech.SsmlVoiceGender.FEMALE)

        return {
            "name": f"{self.language_code}-Standard-A",
            "speaking_rate": 1.0,
            "pitch": None,
            "ssml_gender": tts_gender,
        }

    def _get_voice_and_audio_config(self, gender: str, character_name: Optional[str] = None) -> tuple:
        """Get appropriate voice and audio configuration based on gender."""
        config_set = self._get_voice_config(gender, character_name)

        voice = texttospeech.VoiceSelectionParams(
            language_code=self.language_code,
            name=config_set["name"],
            ssml_gender=config_set["ssml_gender"],
        )

        audio_config_params = {
            "audio_encoding": texttospeech.AudioEncoding.MP3,  # MP3 for faster processing
            "speaking_rate": config_set["speaking_rate"],
            # Remove effects profile for faster generation
        }

        if config_set["pitch"] is not None:
            audio_config_params["pitch"] = config_set["pitch"]

        audio_config = texttospeech.AudioConfig(**audio_config_params)

        return voice, audio_config

    async def synthesize_speech(self, text: str, gender: str = "female", character_name: Optional[str] = None) -> bytes:
        """Convert text to speech using Google Cloud Text-to-Speech."""
        try:
            logger.info(f"TTS synthesize_speech called with text: '{text}', gender: '{gender}', character: '{character_name}'")

            voice, audio_config = self._get_voice_and_audio_config(gender, character_name)
            logger.info(f"Using voice: {voice.name}, language: {self.language_code}")

            synthesis_input = texttospeech.SynthesisInput(text=text)

            response = self.client.synthesize_speech(
                input=synthesis_input,
                voice=voice,
                audio_config=audio_config,
            )

            logger.info(f"TTS successful, audio length: {len(response.audio_content)} bytes")
            return response.audio_content

        except exceptions.GoogleAPICallError as e:
            logger.error(f"Text-to-speech error: {str(e)}")
            raise Exception(f"Text-to-speech error: {str(e)}")


class BaseAIConversationService:
    def __init__(self, language_code: str = "en"):
        self.client = openai.OpenAI(api_key=config.OPENAI_API_KEY)
        self.model = config.OPENAI_MODEL
        self.language_code = language_code
        self.current_personality = None
        self.conversation_history: List[Dict[str, str]] = []
        self.goal_progress: List[Dict[str, Any]] = []

    def set_personality(self, personality):
        """Set the current personality for the conversation."""
        self.current_personality = personality
        self.conversation_history = []
        if hasattr(personality, 'goal_items'):
            self.goal_progress = [item.dict() for item in personality.goal_items]

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history = []
        if self.current_personality and hasattr(self.current_personality, 'goal_items'):
            self.goal_progress = [item.dict() for item in self.current_personality.goal_items]

    def get_personality_for_scenario(self, scenario: str, character_name: Optional[str] = None):
        """Override in language-specific implementations"""
        raise NotImplementedError("Must be implemented by language-specific service")

    async def check_goal_completion(self, user_message: str, ai_response: str) -> bool:
        """Check if any goals are completed using LLM judge."""
        if not self.goal_progress:
            return False

        goals_completed = False
        incomplete_goals = [g for g in self.goal_progress if not g.get('completed', False)]
        if not incomplete_goals:
            return False

        logger.info(f"Checking goal completion for user message: '{user_message}'")

        conversation_context = ""
        for exchange in self.conversation_history[-3:]:
            conversation_context += f"User: {exchange['user']}\nAI: {exchange['assistant']}\n"

        for goal in incomplete_goals:
            completion_check = await self._judge_goal_completion(
                goal,
                user_message,
                ai_response,
                conversation_context,
            )

            if completion_check:
                goal['completed'] = True
                goals_completed = True
                logger.info(f"✅ Goal completed: {goal['description']}")

        return goals_completed

    async def _judge_goal_completion(self, goal, user_message: str, ai_response: str, conversation_context: str) -> bool:
        """Use LLM to judge if a specific goal was completed."""
        try:
            if "order" in goal['description'].lower() or "buy" in goal['description'].lower():
                judge_prompt = f"""You are a strict judge determining if a specific goal was FULLY completed in a conversation.

GOAL TO CHECK: {goal['description']}

RECENT CONVERSATION CONTEXT:
{conversation_context}

LATEST EXCHANGE:
User: {user_message}
AI: {ai_response}

CRITICAL RULES FOR ORDERING GOALS:
1. ONLY return "YES" if the user has COMPLETELY finished this exact goal
2. Return "NO" if the goal is partial, incomplete, or just being discussed
3. For "Order [item]" goals: user must explicitly say they want/order that EXACT item
4. Don't mark as complete just because the AI is asking about it

Answer ONLY "YES" or "NO":"""
            else:
                judge_prompt = f"""You are judging if a conversational goal was completed in a natural conversation scenario.

GOAL TO CHECK: {goal['description']}

RECENT CONVERSATION CONTEXT:
{conversation_context}

LATEST EXCHANGE:
User: {user_message}
AI: {ai_response}

RULES FOR CONVERSATION GOALS:
1. Return "YES" if the user has naturally accomplished this conversational goal
2. Goals can be completed through natural conversation flow
3. Check the FULL conversation context, not just the latest exchange

Answer ONLY "YES" or "NO":"""

            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": judge_prompt}],
                max_tokens=5,
                temperature=0.1,
            )

            result = response.choices[0].message.content.strip().upper()
            return result == "YES"

        except Exception as e:
            logger.error(f"Error in goal completion judge: {str(e)}")
            return False

    def are_all_goals_completed(self) -> bool:
        """Check if all goals are completed."""
        return all(goal.get('completed', False) for goal in self.goal_progress)

    def get_goal_status(self) -> Dict[str, Any]:
        """Get current goal status."""
        return {
            "scenario_goal": self.current_personality.scenario_goal if self.current_personality else "",
            "goal_items": [
                {
                    "id": goal.get('id'),
                    "description": goal.get('description'),
                    "completed": goal.get('completed', False),
                }
                for goal in self.goal_progress
            ],
            "all_completed": self.are_all_goals_completed(),
        }

    async def get_goal_status_async(self) -> Dict[str, Any]:
        """Async version of get_goal_status for parallel processing."""
        return self.get_goal_status()

    async def get_response(self, user_message: str, context: str = "") -> str:
        """Get AI response to user message using current personality."""
        try:
            if not self.current_personality:
                raise Exception("No personality set")

            system_prompt = self.current_personality.get_system_prompt(context)

            messages = [{"role": "system", "content": system_prompt}]

            # Keep only the eight most recent exchanges to bound prompt size.
            recent_history = self.conversation_history[-8:]
            for exchange in recent_history:
                messages.append({"role": "user", "content": exchange["user"]})
                messages.append({"role": "assistant", "content": exchange["assistant"]})

            messages.append({"role": "user", "content": user_message})

            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=250,
                temperature=0.7,
            )

            ai_response = response.choices[0].message.content

            self.conversation_history.append({
                "user": user_message,
                "assistant": ai_response,
            })

            await self.check_goal_completion(user_message, ai_response)

            return ai_response
        except Exception as e:
            return f"Sorry, there was an error: {str(e)}"


class BaseConversationFlowService:
    def __init__(self, language_code: str = "en-US"):
        self.language_code = language_code
        self.stt_service = SpeechToTextService(language_code)
        self.tts_service = TextToSpeechService(language_code)
        self.ai_service = BaseAIConversationService(language_code.split('-')[0])

    def set_scenario_personality(self, scenario: str, character_name: Optional[str] = None):
        """Set the personality based on scenario and character."""
        personality = self.ai_service.get_personality_for_scenario(scenario, character_name)
        if not self.ai_service.current_personality or self.ai_service.current_personality.name != personality.name:
            logger.info(f"Setting new personality: {personality.name}")
            self.ai_service.set_personality(personality)

    async def generate_initial_greeting(self, scenario_context: str = "") -> Dict[str, Any]:
        """Generate initial greeting from character."""
        try:
            scenario = self.extract_scenario_from_context(scenario_context)
            if scenario:
                self.set_scenario_personality(scenario)

            # Generate greeting based on personality
            personality = self.ai_service.current_personality
            if personality and personality.typical_phrases:
                greeting = personality.typical_phrases[0]  # Use first typical phrase
            else:
                greeting = "Hello!"

            # Generate audio
            gender = personality.gender.value if personality else "female"
            personality_name = personality.name if personality else "Character"

            audio_content = await self.tts_service.synthesize_speech(greeting, gender, personality_name)
            audio_base64 = base64.b64encode(audio_content).decode('utf-8')

            return {
                "type": "ai_response",
                "text": greeting,
                "audio": audio_base64,
                "audio_format": "mp3",
                "character": personality_name,
                "is_initial_greeting": True,
            }
        except Exception as e:
            return {
                "type": "error",
                "message": f"Initial greeting error: {str(e)}",
            }

    async def process_conversation_flow_fast(self, transcribed_text: str, scenario_context: str = "") -> Dict[str, Any]:
        """Fast conversation flow with parallel processing."""
        try:
            scenario = self.extract_scenario_from_context(scenario_context)
            if scenario:
                self.set_scenario_personality(scenario)

            # Get personality info early
            gender = self.ai_service.current_personality.gender.value if self.ai_service.current_personality else "female"
            personality_name = self.ai_service.current_personality.name if self.ai_service.current_personality else "Unknown"

            # Start AI response generation and goal checking in parallel
            ai_task = asyncio.create_task(self.ai_service.get_response(transcribed_text, scenario_context))
            goal_task = asyncio.create_task(self.ai_service.get_goal_status_async())

            # Wait for AI response
            ai_response = await ai_task

            # Start TTS immediately while goal processing might still be running
            tts_task = asyncio.create_task(self.tts_service.synthesize_speech(ai_response, gender, personality_name))

            # Get goal status (might already be done)
            goal_status = await goal_task

            # Wait for TTS to complete
            audio_content = await tts_task
            audio_base64 = base64.b64encode(audio_content).decode('utf-8')

            return {
                "type": "ai_response",
                "text": ai_response,
                "audio": audio_base64,
                "audio_format": "mp3",
                "character": personality_name,
                "goal_status": goal_status,
                "conversation_complete": goal_status.get("all_completed", False),
            }
        except Exception as e:
            return {
                "type": "error",
                "message": f"Conversation flow error: {str(e)}",
            }

    async def process_conversation_flow(self, transcribed_text: str, scenario_context: str = "") -> Dict[str, Any]:
        """Process the complete conversation flow: Text → AI → Speech."""
        try:
            scenario = self.extract_scenario_from_context(scenario_context)
            if scenario:
                self.set_scenario_personality(scenario)

            ai_response = await self.ai_service.get_response(transcribed_text, scenario_context)

            gender = self.ai_service.current_personality.gender.value if self.ai_service.current_personality else "female"
            personality_name = self.ai_service.current_personality.name if self.ai_service.current_personality else "Unknown"

            audio_content = await self.tts_service.synthesize_speech(ai_response, gender, personality_name)
            audio_base64 = base64.b64encode(audio_content).decode('utf-8')

            goal_status = self.ai_service.get_goal_status()

            return {
                "type": "ai_response",
                "text": ai_response,
                "audio": audio_base64,
                "audio_format": "mp3",
                "character": personality_name,
                "goal_status": goal_status,
                "conversation_complete": goal_status.get("all_completed", False),
            }
        except Exception as e:
            return {
                "type": "error",
                "message": f"Conversation flow error: {str(e)}",
            }

    def extract_scenario_from_context(self, context: str) -> str:
        """Override in language-specific implementations"""
        return "default"