Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions firebase_admin/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from base64 import b64encode
from typing import Any, Optional, Dict
from dataclasses import dataclass

from google.auth.compute_engine import Credentials as ComputeEngineCredentials
from google.auth.credentials import TokenState
from google.auth.exceptions import RefreshError
from google.auth.transport import requests as google_auth_requests

import requests
import firebase_admin
Expand Down Expand Up @@ -285,14 +289,22 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str
# Get function url from task or generate from resources
if not _Validators.is_non_empty_string(task.http_request['url']):
task.http_request['url'] = self._get_url(resource, _FIREBASE_FUNCTION_URL_FORMAT)

# Refresh the credential to ensure all attributes (e.g. service_account_email, id_token)
# are populated, preventing cold start errors.
if self._credential.token_state != TokenState.FRESH:
try:
self._credential.refresh(google_auth_requests.Request())
except RefreshError as err:
raise ValueError(f'Initial task payload credential refresh failed: {err}') from err

# If extension id is provided, it emplies that it is being run from a deployed extension.
# Meaning that it's credential should be a Compute Engine Credential.
if _Validators.is_non_empty_string(extension_id) and \
isinstance(self._credential, ComputeEngineCredentials):

id_token = self._credential.token
task.http_request['headers'] = \
{**task.http_request['headers'], 'Authorization': f'Bearer ${id_token}'}
{**task.http_request['headers'], 'Authorization': f'Bearer {id_token}'}
# Delete oidc token
del task.http_request['oidc_token']
else:
Expand Down
57 changes: 57 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def test_task_enqueue(self):
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'}
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}

def test_task_enqueue_with_extension(self):
resource_name = (
'projects/test-project/locations/us-central1/queues/'
Expand All @@ -142,6 +146,59 @@ def test_task_enqueue_with_extension(self):
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'}
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}

def test_task_enqueue_compute_engine(self):
app = firebase_admin.initialize_app(
testutils.MockComputeEngineCredential(),
options={'projectId': 'test-project'},
name='test-project-gce')
_, recorder = self._instrument_functions_service(app)
queue = functions.task_queue('test-function-name', app=app)
task_id = queue.enqueue(_DEFAULT_DATA)
assert len(recorder) == 1
assert recorder[0].method == 'POST'
assert recorder[0].url == _DEFAULT_REQUEST_URL
assert recorder[0].headers['Content-Type'] == 'application/json'
assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token'
expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag'
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-gce-email'}
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}

def test_task_enqueue_with_extension_compute_engine(self):
resource_name = (
'projects/test-project/locations/us-central1/queues/'
'ext-test-extension-id-test-function-name/tasks'
)
extension_response = json.dumps({'name': resource_name + '/test-task-id'})
app = firebase_admin.initialize_app(
testutils.MockComputeEngineCredential(),
options={'projectId': 'test-project'},
name='test-project-gce-extensions')
_, recorder = self._instrument_functions_service(app, payload=extension_response)
queue = functions.task_queue('test-function-name', 'test-extension-id', app)
task_id = queue.enqueue(_DEFAULT_DATA)
assert len(recorder) == 1
assert recorder[0].method == 'POST'
assert recorder[0].url == _CLOUD_TASKS_URL + resource_name
assert recorder[0].headers['Content-Type'] == 'application/json'
assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token'
expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag'
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert 'oidc_token' not in task['http_request']
assert task['http_request']['headers'] == {
'Content-Type': 'application/json',
'Authorization': 'Bearer mock-compute-engine-token'}

def test_task_delete(self):
_, recorder = self._instrument_functions_service()
queue = functions.task_queue('test-function-name')
Expand Down
31 changes: 30 additions & 1 deletion tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,25 @@ def __call__(self, *args, **kwargs): # pylint: disable=arguments-differ
# pylint: disable=abstract-method
class MockGoogleCredential(credentials.Credentials):
"""A mock Google authentication credential."""

def __init__(self):
super().__init__()
self.token = None
self._service_account_email = None
self._token_state = credentials.TokenState.INVALID

def refresh(self, request):
self.token = 'mock-token'
self._service_account_email = 'mock-email'
self._token_state = credentials.TokenState.FRESH

@property
def token_state(self):
return self._token_state

@property
def service_account_email(self):
return 'mock-email'
return self._service_account_email

# Simulate x-goog-api-client modification in credential refresh
def _metric_header_for_usage(self):
Expand All @@ -139,8 +152,24 @@ def get_credential(self):

class MockGoogleComputeEngineCredential(compute_engine.Credentials):
"""A mock Compute Engine credential"""

def __init__(self):
super().__init__()
self.token = None
self._service_account_email = None
self._token_state = credentials.TokenState.INVALID

def refresh(self, request):
self.token = 'mock-compute-engine-token'
self._service_account_email = 'mock-gce-email'
self._token_state = credentials.TokenState.FRESH

@property
def token_state(self):
return self._token_state

def _metric_header_for_usage(self):
return 'mock-gce-cred-metric-tag'

class MockComputeEngineCredential(firebase_admin.credentials.Base):
"""A mock Firebase credential implementation."""
Expand Down